diff --git a/.editorconfig b/.editorconfig index 1eadda334..c90b171f5 100644 --- a/.editorconfig +++ b/.editorconfig @@ -48,3 +48,7 @@ end_of_line = unset charset = unset trim_trailing_whitespace = unset insert_final_newline = unset + +[vendor/miniaudio/miniaudio.h] +trim_trailing_whitespace = unset +insert_final_newline = unset diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index dbd31e589..92dc41f9d 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -26,12 +26,12 @@ jobs: sudo apt-get install -y --no-install-recommends \ build-essential \ gcc-14-riscv64-linux-gnu \ - g++-14-riscv64-linux-gnu \ - libcurl4-openssl-dev:riscv64 + g++-14-riscv64-linux-gnu - name: Build run: | - cmake -B build -DCMAKE_BUILD_TYPE=Release \ + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ -DGGML_OPENMP=OFF \ -DLLAMA_BUILD_EXAMPLES=ON \ -DLLAMA_BUILD_TOOLS=ON \ @@ -72,12 +72,12 @@ jobs: glslc \ gcc-14-riscv64-linux-gnu \ g++-14-riscv64-linux-gnu \ - libvulkan-dev:riscv64 \ - libcurl4-openssl-dev:riscv64 + libvulkan-dev:riscv64 - name: Build run: | - cmake -B build -DCMAKE_BUILD_TYPE=Release \ + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ -DGGML_VULKAN=ON \ -DGGML_OPENMP=OFF \ -DLLAMA_BUILD_EXAMPLES=ON \ @@ -118,12 +118,12 @@ jobs: build-essential \ glslc \ crossbuild-essential-arm64 \ - libvulkan-dev:arm64 \ - libcurl4-openssl-dev:arm64 + libvulkan-dev:arm64 - name: Build run: | - cmake -B build -DCMAKE_BUILD_TYPE=Release \ + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ -DGGML_VULKAN=ON \ -DGGML_OPENMP=OFF \ -DLLAMA_BUILD_EXAMPLES=ON \ @@ -163,12 +163,12 @@ jobs: sudo apt-get install -y --no-install-recommends \ build-essential \ gcc-14-powerpc64le-linux-gnu \ - g++-14-powerpc64le-linux-gnu \ - libcurl4-openssl-dev:ppc64el + g++-14-powerpc64le-linux-gnu - name: Build run: | - cmake -B build -DCMAKE_BUILD_TYPE=Release \ + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ -DGGML_OPENMP=OFF \ -DLLAMA_BUILD_EXAMPLES=ON \ -DLLAMA_BUILD_TOOLS=ON \ @@ -209,12 +209,12 @@ jobs: glslc \ gcc-14-powerpc64le-linux-gnu \ g++-14-powerpc64le-linux-gnu \ - libvulkan-dev:ppc64el \ - libcurl4-openssl-dev:ppc64el + libvulkan-dev:ppc64el - name: Build run: | - cmake -B build -DCMAKE_BUILD_TYPE=Release \ + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ -DGGML_VULKAN=ON \ -DGGML_OPENMP=OFF \ -DLLAMA_BUILD_EXAMPLES=ON \ diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ed827bf70..65ed24465 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,4 +1,4 @@ -name: Create Release +name: Release on: workflow_dispatch: # allows manual triggering @@ -227,6 +227,69 @@ jobs: path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip name: llama-bin-ubuntu-vulkan-x64.zip + windows-cpu: + runs-on: windows-latest + + strategy: + matrix: + include: + - arch: 'x64' + - arch: 'arm64' + + steps: + - name: Clone + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-cpu-${{ matrix.arch }} + variant: ccache + evict-old-files: 1d + + - name: Install Ninja + run: | + choco install ninja + + - name: libCURL + id: get_libcurl + uses: ./.github/actions/windows-setup-curl + with: + architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }} + + - name: Build + shell: cmd + env: + CURL_PATH: ${{ 
steps.get_libcurl.outputs.curl_path }} + run: | + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch }} + cmake -S . -B build -G "Ninja Multi-Config" ^ + -D CMAKE_TOOLCHAIN_FILE=cmake/${{ matrix.arch }}-windows-llvm.cmake ^ + -DGGML_NATIVE=OFF ^ + -DGGML_BACKEND_DL=ON ^ + -DGGML_CPU_ALL_VARIANTS=${{ matrix.arch == 'x64' && 'ON' || 'OFF' }} ^ + -DGGML_OPENMP=ON ^ + -DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include" ^ + ${{ env.CMAKE_ARGS }} + cmake --build build --config Release + + - name: Pack artifacts + id: pack_artifacts + env: + CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} + run: | + Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\ + Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.42.34433\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\ + 7z a llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-bin-win-cpu-${{ matrix.arch }}.zip + name: llama-bin-win-cpu-${{ matrix.arch }}.zip + windows: runs-on: windows-latest @@ -237,52 +300,30 @@ jobs: strategy: matrix: include: - - build: 'cpu-x64' + - backend: 'vulkan' arch: 'x64' - defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF' - #- build: 'openblas-x64' - # arch: 'x64' - # defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"' - - build: 'vulkan-x64' - arch: 'x64' - defines: '-DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_VULKAN=ON' - - build: 'cpu-arm64' - arch: 'arm64' - defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF' - - build: 'opencl-adreno-arm64' + defines: '-DGGML_VULKAN=ON' + target: 'ggml-vulkan' + - backend: 'opencl-adreno' arch: 'arm64' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON' + target: 'ggml-opencl' steps: - name: Clone id: checkout uses: actions/checkout@v4 - with: - fetch-depth: 0 - name: ccache uses: hendrikmuhs/ccache-action@v1.2.16 with: - key: windows-latest-cmake-${{ matrix.build }} + key: windows-latest-cmake-${{ matrix.backend }}-${{ matrix.arch }} variant: ccache evict-old-files: 1d - - name: Download OpenBLAS - id: get_openblas - if: ${{ matrix.build == 'openblas-x64' }} - run: | - curl.exe -o $env:RUNNER_TEMP/openblas.zip -L "https://github.com/xianyi/OpenBLAS/releases/download/v${env:OPENBLAS_VERSION}/OpenBLAS-${env:OPENBLAS_VERSION}-x64.zip" - curl.exe -o $env:RUNNER_TEMP/OpenBLAS.LICENSE.txt -L "https://github.com/xianyi/OpenBLAS/raw/v${env:OPENBLAS_VERSION}/LICENSE" - mkdir $env:RUNNER_TEMP/openblas - tar.exe -xvf $env:RUNNER_TEMP/openblas.zip -C $env:RUNNER_TEMP/openblas - $vcdir = $(vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath) - $msvc 
= $(join-path $vcdir $('VC\Tools\MSVC\'+$(gc -raw $(join-path $vcdir 'VC\Auxiliary\Build\Microsoft.VCToolsVersion.default.txt')).Trim())) - $lib = $(join-path $msvc 'bin\Hostx64\x64\lib.exe') - & $lib /machine:x64 "/def:${env:RUNNER_TEMP}/openblas/lib/libopenblas.def" "/out:${env:RUNNER_TEMP}/openblas/lib/openblas.lib" /name:openblas.dll - - name: Install Vulkan SDK id: get_vulkan - if: ${{ matrix.build == 'vulkan-x64' }} + if: ${{ matrix.backend == 'vulkan' }} run: | curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe" & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install @@ -296,7 +337,7 @@ jobs: - name: Install OpenCL Headers and Libs id: install_opencl - if: ${{ matrix.build == 'opencl-adreno-arm64' }} + if: ${{ matrix.backend == 'opencl-adreno' && matrix.arch == 'arm64' }} run: | git clone https://github.com/KhronosGroup/OpenCL-Headers cd OpenCL-Headers @@ -314,46 +355,22 @@ jobs: -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" cmake --build build-arm64-release --target install --config release - - name: libCURL - id: get_libcurl - uses: ./.github/actions/windows-setup-curl - with: - architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }} - - name: Build id: cmake_build - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | - cmake -S . -B build ${{ matrix.defines }} ` - -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" ` - ${{ env.CMAKE_ARGS }} - cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} - - - name: Add libopenblas.dll - id: add_libopenblas_dll - if: ${{ matrix.build == 'openblas-x64' }} - run: | - cp $env:RUNNER_TEMP/openblas/bin/libopenblas.dll ./build/bin/Release/openblas.dll - cp $env:RUNNER_TEMP/OpenBLAS.LICENSE.txt ./build/bin/Release/OpenBLAS-${env:OPENBLAS_VERSION}.txt - - - name: Determine tag name - id: tag - uses: ./.github/actions/get-tag-name + cmake -S . -B build ${{ matrix.defines }} -DGGML_NATIVE=OFF -DGGML_CPU=OFF -DGGML_BACKEND_DL=ON -DLLAMA_CURL=OFF + cmake --build build --config Release --target ${{ matrix.target }} - name: Pack artifacts id: pack_artifacts - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | - Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\ - 7z a llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip .\build\bin\Release\* + 7z a llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll - name: Upload artifacts uses: actions/upload-artifact@v4 with: - path: llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip - name: llama-bin-win-${{ matrix.build }}.zip + path: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip + name: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip windows-cuda: runs-on: windows-2019 @@ -366,8 +383,6 @@ jobs: - name: Clone id: checkout uses: actions/checkout@v4 - with: - fetch-depth: 0 - name: Install ccache uses: hendrikmuhs/ccache-action@v1.2.16 @@ -386,45 +401,30 @@ jobs: run: | choco install ninja - - name: libCURL - id: get_libcurl - uses: ./.github/actions/windows-setup-curl - - name: Build id: cmake_build shell: cmd - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat" cmake -S . 
-B build -G "Ninja Multi-Config" ^ - -DGGML_NATIVE=OFF ^ -DGGML_BACKEND_DL=ON ^ - -DGGML_CPU_ALL_VARIANTS=ON ^ + -DGGML_NATIVE=OFF ^ + -DGGML_CPU=OFF ^ -DGGML_CUDA=ON ^ - -DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include" ^ - ${{ env.CMAKE_ARGS }} + -DLLAMA_CURL=OFF set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 - cmake --build build --config Release -j %NINJA_JOBS% -t ggml - cmake --build build --config Release - - - name: Determine tag name - id: tag - uses: ./.github/actions/get-tag-name + cmake --build build --config Release -j %NINJA_JOBS% --target ggml-cuda - name: Pack artifacts id: pack_artifacts - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | - cp $env:CURL_PATH\bin\libcurl-x64.dll .\build\bin\Release\libcurl-x64.dll - 7z a llama-${{ steps.tag.outputs.name }}-bin-win-cuda${{ matrix.cuda }}-x64.zip .\build\bin\Release\* + 7z a llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll - name: Upload artifacts uses: actions/upload-artifact@v4 with: - path: llama-${{ steps.tag.outputs.name }}-bin-win-cuda${{ matrix.cuda }}-x64.zip - name: llama-bin-win-cuda${{ matrix.cuda }}-x64.zip + path: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip + name: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip - name: Copy and pack Cuda runtime run: | @@ -432,13 +432,13 @@ jobs: $dst='.\build\bin\cudart\' robocopy "${{env.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll robocopy "${{env.CUDA_PATH}}\lib" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll - 7z a cudart-llama-bin-win-cuda${{ matrix.cuda }}-x64.zip $dst\* + 7z a cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip $dst\* - name: Upload Cuda runtime uses: actions/upload-artifact@v4 with: - path: cudart-llama-bin-win-cuda${{ matrix.cuda }}-x64.zip - name: cudart-llama-bin-win-cuda${{ matrix.cuda }}-x64.zip + path: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip + name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip windows-sycl: runs-on: windows-latest @@ -451,12 +451,11 @@ jobs: WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI" + steps: - name: Clone id: checkout uses: actions/checkout@v4 - with: - fetch-depth: 0 - name: ccache uses: hendrikmuhs/ccache-action@v1.2.16 @@ -469,15 +468,18 @@ jobs: run: | scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL - # TODO: add libcurl support ; we will also need to modify win-build-sycl.bat to accept user-specified args - - name: Build id: cmake_build - run: examples/sycl/win-build-sycl.bat - - - name: Determine tag name - id: tag - uses: ./.github/actions/get-tag-name + shell: cmd + run: | + call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force + cmake -G "Ninja" -B build ^ + -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx ^ + -DCMAKE_BUILD_TYPE=Release ^ + -DGGML_BACKEND_DL=ON -DBUILD_SHARED_LIBS=ON ^ + -DGGML_CPU=OFF -DGGML_SYCL=ON ^ + -DLLAMA_CURL=OFF + cmake --build build --target ggml-sycl -j - name: Build the release package id: pack_artifacts @@ -502,12 +504,12 @@ jobs: cp "${{ env.ONEAPI_ROOT }}/tbb/latest/bin/tbb12.dll" ./build/bin echo "cp oneAPI running time dll files to ./build/bin done" - 7z a llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip 
./build/bin/* + 7z a llama-bin-win-sycl-x64.zip ./build/bin/* - name: Upload the release package uses: actions/upload-artifact@v4 with: - path: llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip + path: llama-bin-win-sycl-x64.zip name: llama-bin-win-sycl-x64.zip windows-hip: @@ -515,14 +517,14 @@ jobs: strategy: matrix: - gpu_target: [gfx1100, gfx1101, gfx1030] + include: + - name: "radeon" + gpu_targets: "gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032" steps: - name: Clone id: checkout uses: actions/checkout@v4 - with: - fetch-depth: 0 - name: Clone rocWMMA repository id: clone_rocwmma @@ -532,7 +534,7 @@ jobs: - name: ccache uses: hendrikmuhs/ccache-action@v1.2.16 with: - key: windows-latest-cmake-hip-release + key: windows-latest-cmake-hip-${{ matrix.name }}-x64 evict-old-files: 1d - name: Install @@ -550,50 +552,39 @@ jobs: run: | & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version - - name: libCURL - id: get_libcurl - uses: ./.github/actions/windows-setup-curl - - name: Build id: cmake_build - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" cmake -G "Unix Makefiles" -B build -S . ` -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` - -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/" ` + -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/ -Wno-ignored-attributes -Wno-nested-anon-types" ` -DCMAKE_BUILD_TYPE=Release ` - -DAMDGPU_TARGETS=${{ matrix.gpu_target }} ` + -DGGML_BACKEND_DL=ON ` + -DGGML_NATIVE=OFF ` + -DGGML_CPU=OFF ` + -DAMDGPU_TARGETS="${{ matrix.gpu_targets }}" ` -DGGML_HIP_ROCWMMA_FATTN=ON ` -DGGML_HIP=ON ` - -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" ` - ${{ env.CMAKE_ARGS }} - cmake --build build -j ${env:NUMBER_OF_PROCESSORS} + -DLLAMA_CURL=OFF + cmake --build build --target ggml-hip -j ${env:NUMBER_OF_PROCESSORS} md "build\bin\rocblas\library\" cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\" cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\" cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\" - - name: Determine tag name - id: tag - uses: ./.github/actions/get-tag-name - - name: Pack artifacts id: pack_artifacts - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | - cp $env:CURL_PATH\bin\libcurl-x64.dll .\build\bin\libcurl-x64.dll - 7z a llama-${{ steps.tag.outputs.name }}-bin-win-hip-x64-${{ matrix.gpu_target }}.zip .\build\bin\* + 7z a llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\* - name: Upload artifacts uses: actions/upload-artifact@v4 with: - path: llama-${{ steps.tag.outputs.name }}-bin-win-hip-x64-${{ matrix.gpu_target }}.zip - name: llama-bin-win-hip-x64-${{ matrix.gpu_target }}.zip + path: llama-bin-win-hip-${{ matrix.name }}-x64.zip + name: llama-bin-win-hip-${{ matrix.name }}-x64.zip ios-xcode-build: runs-on: macos-latest @@ -655,14 +646,16 @@ jobs: runs-on: ubuntu-latest needs: - - ubuntu-22-cpu - - ubuntu-22-vulkan - windows + - windows-cpu - windows-cuda - windows-sycl - windows-hip + - ubuntu-22-cpu + - ubuntu-22-vulkan - macOS-arm64 - macOS-x64 + - ios-xcode-build steps: - name: Clone @@ -680,10 +673,43 @@ jobs: uses: actions/download-artifact@v4 with: path: ./artifact + merge-multiple: true - name: Move artifacts id: move_artifacts - run: mkdir -p ./artifact/release && mv 
./artifact/*/*.zip ./artifact/release + run: | + mkdir -p release + + echo "Adding CPU backend files to existing zips..." + for arch in x64 arm64; do + cpu_zip="artifact/llama-bin-win-cpu-${arch}.zip" + temp_dir=$(mktemp -d) + echo "Extracting CPU backend for $arch..." + unzip "$cpu_zip" -d "$temp_dir" + + echo "Adding CPU files to $arch zips..." + for target_zip in artifact/llama-bin-win-*-${arch}.zip; do + if [[ "$target_zip" == "$cpu_zip" ]]; then + continue + fi + echo "Adding CPU backend to $(basename "$target_zip")" + realpath_target_zip=$(realpath "$target_zip") + (cd "$temp_dir" && zip -r "$realpath_target_zip" .) + done + + rm -rf "$temp_dir" + done + + echo "Renaming and moving zips to release..." + for zip_file in artifact/llama-bin-win-*.zip; do + base_name=$(basename "$zip_file" .zip) + zip_name="llama-${{ steps.tag.outputs.name }}-${base_name#llama-}.zip" + echo "Moving $zip_file to release/$zip_name" + mv "$zip_file" "release/$zip_name" + done + + echo "Moving other artifacts..." + mv -v artifact/*.zip release - name: Create release id: create_release @@ -702,7 +728,7 @@ jobs: const path = require('path'); const fs = require('fs'); const release_id = '${{ steps.create_release.outputs.id }}'; - for (let file of await fs.readdirSync('./artifact/release')) { + for (let file of await fs.readdirSync('./release')) { if (path.extname(file) === '.zip') { console.log('uploadReleaseAsset', file); await github.repos.uploadReleaseAsset({ @@ -710,7 +736,7 @@ jobs: repo: context.repo.repo, release_id: release_id, name: file, - data: await fs.readFileSync(`./artifact/release/${file}`) + data: await fs.readFileSync(`./release/${file}`) }); } } diff --git a/.github/workflows/winget.yml b/.github/workflows/winget.yml new file mode 100644 index 000000000..5c2861559 --- /dev/null +++ b/.github/workflows/winget.yml @@ -0,0 +1,42 @@ +name: Update Winget Package + +on: + workflow_dispatch: # allows manual triggering + schedule: + - cron: '28 5 * * *' # Update every day at 5:28 UTC + +jobs: + update: + name: Update Winget Package + runs-on: ubuntu-latest + + steps: + - name: Install cargo binstall + uses: cargo-bins/cargo-binstall@268643a6b5ea099f5718ee5cd3ff7dc89a5eb49b + + - name: Install komac + run: | + cargo binstall komac@2.11.2 -y + + - name: Find latest release + id: find_latest_release + uses: actions/github-script@v6 + with: + script: | + const { data: releases } = await github.rest.repos.listReleases({ + owner: context.repo.owner, + repo: context.repo.repo, + }); + console.log("Latest release:", releases[0].tag_name); + return releases[0].tag_name; + + - name: Update manifest + env: + VERSION: ${{ steps.find_latest_release.outputs.result }} + run: | + echo "Updating manifest..." + komac update --version ${{ env.VERSION }} \ + --urls "https://github.com/ggml-org/llama.cpp/releases/download/${{ env.VERSION }}/llama-${{ env.VERSION }}-bin-win-vulkan-x64.zip" \ + --token ${{ secrets.WINGET_GITHUB_TOKEN }} \ + --submit \ + ggml.llamacpp diff --git a/README.md b/README.md index d1cb8d833..576332bc5 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
Bindings +- Python: [ddh0/easy-llama](https://github.com/ddh0/easy-llama) - Python: [abetlen/llama-cpp-python](https://github.com/abetlen/llama-cpp-python) - Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) - Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp) @@ -580,3 +581,4 @@ $ echo "source ~/.llama-completion.bash" >> ~/.bashrc - [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License - [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License - [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html) +- [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index a7ff3ac16..564af1448 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -58,19 +58,20 @@ add_library(${TARGET} STATIC arg.cpp arg.h base64.hpp + chat-parser.cpp + chat-parser.h chat.cpp chat.h common.cpp common.h console.cpp console.h + json-partial.cpp + json-partial.h json-schema-to-grammar.cpp - json.hpp llguidance.cpp log.cpp log.h - minja/chat-template.hpp - minja/minja.hpp ngram-cache.cpp ngram-cache.h regex-partial.cpp @@ -143,7 +144,7 @@ if (LLAMA_LLGUIDANCE) set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS}) endif () -target_include_directories(${TARGET} PUBLIC .) +target_include_directories(${TARGET} PUBLIC . ../vendor) target_compile_features (${TARGET} PUBLIC cxx_std_17) target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads) diff --git a/common/arg.cpp b/common/arg.cpp index 997f732cc..cfa9878f9 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1,10 +1,11 @@ -#include "gguf.h" // for reading GGUF splits #include "arg.h" +#include "chat.h" #include "common.h" +#include "gguf.h" // for reading GGUF splits +#include "json-schema-to-grammar.h" #include "log.h" #include "sampling.h" -#include "chat.h" // fix problem with std::min and std::max #if defined(_WIN32) @@ -15,6 +16,9 @@ #include #endif +#define JSON_ASSERT GGML_ASSERT +#include + #include #include #include @@ -34,12 +38,10 @@ #include #endif -#include "json-schema-to-grammar.h" - using json = nlohmann::ordered_json; std::initializer_list mmproj_examples = { - LLAMA_EXAMPLE_LLAVA, + LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_SERVER, }; @@ -242,7 +244,56 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma } // download one single file from remote URL to local path -static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token) { +static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) { + // Check if the file already exists locally + auto file_exists = std::filesystem::exists(path); + + // If the file exists, check its JSON metadata companion file. 
+ std::string metadata_path = path + ".json"; + nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead + std::string etag; + std::string last_modified; + + if (file_exists) { + if (offline) { + LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str()); + return true; // skip verification/downloading + } + // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). + std::ifstream metadata_in(metadata_path); + if (metadata_in.good()) { + try { + metadata_in >> metadata; + LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); + if (metadata.contains("etag") && metadata.at("etag").is_string()) { + etag = metadata.at("etag"); + } + if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { + last_modified = metadata.at("lastModified"); + } + } catch (const nlohmann::json::exception & e) { + LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + } + } + // if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again) + } else { + if (offline) { + LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str()); + return false; + } + LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); + } + + // Send a HEAD request to retrieve the etag and last-modified headers + struct common_load_model_from_url_headers { + std::string etag; + std::string last_modified; + }; + + common_load_model_from_url_headers headers; + bool head_request_ok = false; + bool should_download = !file_exists; // by default, we should download if the file does not exist + // Initialize libcurl curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); curl_slist_ptr http_headers; @@ -269,91 +320,47 @@ static bool common_download_file_single(const std::string & url, const std::stri curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); #endif - // Check if the file already exists locally - auto file_exists = std::filesystem::exists(path); + typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); + auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { + common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; - // If the file exists, check its JSON metadata companion file. - std::string metadata_path = path + ".json"; - nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead - std::string etag; - std::string last_modified; + static std::regex header_regex("([^:]+): (.*)\r\n"); + static std::regex etag_regex("ETag", std::regex_constants::icase); + static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); - if (file_exists) { - // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). 
- std::ifstream metadata_in(metadata_path); - if (metadata_in.good()) { - try { - metadata_in >> metadata; - LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); - if (metadata.contains("etag") && metadata.at("etag").is_string()) { - etag = metadata.at("etag"); - } - if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { - last_modified = metadata.at("lastModified"); - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + std::string header(buffer, n_items); + std::smatch match; + if (std::regex_match(header, match, header_regex)) { + const std::string & key = match[1]; + const std::string & value = match[2]; + if (std::regex_match(key, match, etag_regex)) { + headers->etag = value; + } else if (std::regex_match(key, match, last_modified_regex)) { + headers->last_modified = value; } } - // if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again) - } else { - LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); - } - - // Send a HEAD request to retrieve the etag and last-modified headers - struct common_load_model_from_url_headers { - std::string etag; - std::string last_modified; + return n_items; }; - common_load_model_from_url_headers headers; - bool head_request_ok = false; - bool should_download = !file_exists; // by default, we should download if the file does not exist + curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress + curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); + curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - // get ETag to see if the remote file has changed - { - typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); - auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { - common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; + // we only allow retrying once for HEAD requests + // this is for the use case of using running offline (no internet), retrying can be annoying + bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD"); + if (!was_perform_successful) { + head_request_ok = false; + } - static std::regex header_regex("([^:]+): (.*)\r\n"); - static std::regex etag_regex("ETag", std::regex_constants::icase); - static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); - - std::string header(buffer, n_items); - std::smatch match; - if (std::regex_match(header, match, header_regex)) { - const std::string & key = match[1]; - const std::string & value = match[2]; - if (std::regex_match(key, match, etag_regex)) { - headers->etag = value; - } else if (std::regex_match(key, match, last_modified_regex)) { - headers->last_modified = value; - } - } - return n_items; - }; - - curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress - curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); - curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - - // we only allow retrying once for HEAD requests - // this is for 
the use case of using running offline (no internet), retrying can be annoying - bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD"); - if (!was_perform_successful) { - head_request_ok = false; - } - - long http_code = 0; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code == 200) { - head_request_ok = true; - } else { - LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); - head_request_ok = false; - } + long http_code = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); + if (http_code == 200) { + head_request_ok = true; + } else { + LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); + head_request_ok = false; } // if head_request_ok is false, we don't have the etag or last-modified headers @@ -460,12 +467,12 @@ static bool common_download_file_single(const std::string & url, const std::stri // download multiple files from remote URLs to local paths // the input is a vector of pairs -static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token) { +static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token, bool offline) { // Prepare download in parallel std::vector> futures_download; for (auto const & item : urls) { - futures_download.push_back(std::async(std::launch::async, [bearer_token](const std::pair & it) -> bool { - return common_download_file_single(it.first, it.second, bearer_token); + futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair & it) -> bool { + return common_download_file_single(it.first, it.second, bearer_token, offline); }, item)); } @@ -481,14 +488,15 @@ static bool common_download_file_multiple(const std::vector> common_remote_get_content(const std::string & * * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. */ -static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) { +static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token, bool offline) { auto parts = string_split(hf_repo_with_tag, ':'); std::string tag = parts.size() > 1 ? 
parts.back() : "latest"; std::string hf_repo = parts[0]; @@ -638,20 +646,25 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_ long res_code = 0; std::string res_str; bool use_cache = false; - try { - auto res = common_remote_get_content(url, params); - res_code = res.first; - res_str = std::string(res.second.data(), res.second.size()); - } catch (const std::exception & e) { - LOG_WRN("error: failed to get manifest: %s\n", e.what()); - LOG_WRN("try reading from cache\n"); - // try to read from cache + if (!offline) { try { + auto res = common_remote_get_content(url, params); + res_code = res.first; + res_str = std::string(res.second.data(), res.second.size()); + } catch (const std::exception & e) { + LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what()); + } + } + if (res_code == 0) { + if (std::filesystem::exists(cached_response_path)) { + LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str()); res_str = read_file(cached_response_path); res_code = 200; use_cache = true; - } catch (const std::exception & e) { - throw std::runtime_error("error: failed to get manifest (check your internet connection)"); + } else { + throw std::runtime_error( + offline ? "error: failed to get manifest (offline mode)" + : "error: failed to get manifest (check your internet connection)"); } } std::string ggufFile; @@ -698,24 +711,25 @@ bool common_has_curl() { return false; } -static bool common_download_file_single(const std::string &, const std::string &, const std::string &) { +static bool common_download_file_single(const std::string &, const std::string &, const std::string &, bool) { LOG_ERR("error: built without CURL, cannot download model from internet\n"); return false; } -static bool common_download_file_multiple(const std::vector> &, const std::string &) { +static bool common_download_file_multiple(const std::vector> &, const std::string &, bool) { LOG_ERR("error: built without CURL, cannot download model from the internet\n"); return false; } static bool common_download_model( const common_params_model &, - const std::string &) { + const std::string &, + bool) { LOG_ERR("error: built without CURL, cannot download model from the internet\n"); return false; } -static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &) { +static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) { LOG_ERR("error: built without CURL, cannot download model from the internet\n"); return {}; } @@ -742,7 +756,8 @@ struct handle_model_result { static handle_model_result common_params_handle_model( struct common_params_model & model, const std::string & bearer_token, - const std::string & model_path_default) { + const std::string & model_path_default, + bool offline) { handle_model_result result; // handle pre-fill default model path and url based on hf_repo and hf_file { @@ -750,7 +765,7 @@ static handle_model_result common_params_handle_model( // short-hand to avoid specifying --hf-file -> default it to --model if (model.hf_file.empty()) { if (model.path.empty()) { - auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token); + auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline); if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) { exit(1); // built without CURL, error message already printed } @@ -791,7 +806,7 @@ static handle_model_result common_params_handle_model( // then, download it if needed if 
(!model.url.empty()) { - bool ok = common_download_model(model, bearer_token); + bool ok = common_download_model(model, bearer_token, offline); if (!ok) { LOG_ERR("error: failed to download model from %s\n", model.url.c_str()); exit(1); @@ -934,7 +949,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // handle model and download { - auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH); + auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH, params.offline); if (params.no_mmproj) { params.mmproj = {}; } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { @@ -944,12 +959,12 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // only download mmproj if the current example is using it for (auto & ex : mmproj_examples) { if (ctx_arg.ex == ex) { - common_params_handle_model(params.mmproj, params.hf_token, ""); + common_params_handle_model(params.mmproj, params.hf_token, "", params.offline); break; } } - common_params_handle_model(params.speculative.model, params.hf_token, ""); - common_params_handle_model(params.vocoder.model, params.hf_token, ""); + common_params_handle_model(params.speculative.model, params.hf_token, "", params.offline); + common_params_handle_model(params.vocoder.model, params.hf_token, "", params.offline); } if (params.escape) { @@ -1333,9 +1348,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex )); add_opt(common_arg( {"--prio"}, "N", - string_format("set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.cpuparams.priority), + string_format("set process/thread priority : low(-1), normal(0), medium(1), high(2), realtime(3) (default: %d)\n", params.cpuparams.priority), [](common_params & params, int prio) { - if (prio < 0 || prio > 3) { + if (prio < GGML_SCHED_PRIO_LOW || prio > GGML_SCHED_PRIO_REALTIME) { throw std::invalid_argument("invalid value"); } params.cpuparams.priority = (enum ggml_sched_priority) prio; @@ -2233,12 +2248,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ_OFFLOAD")); add_opt(common_arg( - {"--image"}, "FILE", - "path to an image file. use with multimodal models. Specify multiple times for batching", + {"--image", "--audio"}, "FILE", + "path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n", [](common_params & params, const std::string & value) { params.image.emplace_back(value); } - ).set_examples({LLAMA_EXAMPLE_LLAVA})); + ).set_examples({LLAMA_EXAMPLE_MTMD})); if (llama_supports_rpc()) { add_opt(common_arg( {"--rpc"}, "SERVERS", @@ -2848,15 +2863,24 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA")); add_opt(common_arg( {"--reasoning-format"}, "FORMAT", - "reasoning format (default: deepseek; allowed values: deepseek, none)\n" - "controls whether thought tags are extracted from the response, and in which format they're returned. 
'none' leaves thoughts unparsed in `message.content`, 'deepseek' puts them in `message.reasoning_content` (for DeepSeek R1 & Command R7B only).\n" - "only supported for non-streamed responses", + "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n" + "- none: leaves thoughts unparsed in `message.content`\n" + "- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n" + "(default: deepseek)", [](common_params & params, const std::string & value) { /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; } else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; } - else { std::invalid_argument("invalid value"); } + else { throw std::invalid_argument("invalid value"); } } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK")); + add_opt(common_arg( + {"--reasoning-budget"}, "N", + "controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)", + [](common_params & params, int value) { + if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); } + params.reasoning_budget = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK_BUDGET")); add_opt(common_arg( {"--chat-template"}, "JINJA_TEMPLATE", string_format( @@ -2868,7 +2892,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.chat_template = value; } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_LLAVA}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); add_opt(common_arg( {"--chat-template-file"}, "JINJA_TEMPLATE_FILE", string_format( @@ -2955,7 +2979,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { /**/ if (value == "jsonl") { params.batched_bench_output_jsonl = true; } else if (value == "md") { params.batched_bench_output_jsonl = false; } - else { std::invalid_argument("invalid value"); } + else { throw std::invalid_argument("invalid value"); } } ).set_examples({LLAMA_EXAMPLE_BENCH})); add_opt(common_arg( @@ -2987,6 +3011,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex common_log_set_verbosity_thold(INT_MAX); } )); + add_opt(common_arg( + {"--offline"}, + "Offline mode: forces use of cache, prevents network access", + [](common_params & params) { + params.offline = true; + } + ).set_env("LLAMA_OFFLINE")); add_opt(common_arg( {"-lv", "--verbosity", "--log-verbosity"}, "N", "Set the verbosity threshold. 
Messages with a higher verbosity will be ignored.", diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp new file mode 100644 index 000000000..65b664cb3 --- /dev/null +++ b/common/chat-parser.cpp @@ -0,0 +1,380 @@ +#include "chat-parser.h" +#include "common.h" +#include "log.h" +#include "regex-partial.h" + +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax) + : input_(input), is_partial_(is_partial), syntax_(syntax) +{ + result_.role = "assistant"; + + while (true) { + std::string id = std::to_string(std::rand()); + if (input.find(id) == std::string::npos) { + healing_marker_ = id; + break; + } + } +} + +std::string common_chat_msg_parser::str(const common_string_range & rng) const { + GGML_ASSERT(rng.begin <= rng.end); + return input_.substr(rng.begin, rng.end - rng.begin); +} + +void common_chat_msg_parser::add_content(const std::string &content) { + result_.content += content; +} + +void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) { + result_.reasoning_content += reasoning_content; +} + +bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) { + if (name.empty()) { + return false; + } + + common_chat_tool_call tool_call; + tool_call.name = name; + tool_call.arguments = arguments; + tool_call.id = id; + + // LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str()); + result_.tool_calls.emplace_back(tool_call); + return true; +} +bool common_chat_msg_parser::add_tool_call(const json & tool_call) { + std::string name = tool_call.contains("name") ? tool_call.at("name") : ""; + std::string id = tool_call.contains("id") ? tool_call.at("id") : ""; + std::string arguments = tool_call.contains("arguments") ? 
tool_call.at("arguments") : ""; + return add_tool_call(name, id, arguments); +} + +bool common_chat_msg_parser::add_tool_calls(const json & arr) { + for (const auto & item : arr) { + if (!add_tool_call(item)) { + return false; + } + } + return true; +} +void common_chat_msg_parser::finish() { + if (!is_partial_ && pos_ != input_.size()) { + throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_)); + } +} + +bool common_chat_msg_parser::consume_spaces() { + const auto length = input_.size(); + auto consumed = false; + while (pos_ < length && std::isspace(input_[pos_])) { + ++pos_; + consumed = true; + } + return consumed; +} + +bool common_chat_msg_parser::try_consume_literal(const std::string & literal) { + auto pos = pos_; + for (auto i = 0u; i < literal.size(); ++i) { + if (pos >= input_.size()) { + return false; + } + if (input_[pos] != literal[i]) { + return false; + } + ++pos; + } + pos_ = pos; + return true; +} + +std::optional common_chat_msg_parser::try_find_literal(const std::string & literal) { + auto idx = input_.find(literal, pos_); + if (idx != std::string::npos) { + find_regex_result res; + res.prelude = input_.substr(pos_, idx - pos_); + auto end = idx + literal.size(); + res.groups.emplace_back(common_string_range{idx, end}); + move_to(end); + return res; + } + if (is_partial_) { + idx = string_find_partial_stop(input_, literal); + if (idx != std::string::npos && idx >= pos_) { + find_regex_result res; + res.prelude = input_.substr(pos_, idx - pos_); + auto end = input_.size(); + res.groups.emplace_back(common_string_range{idx, end}); + move_to(end); + return res; + } + } + return std::nullopt; +} + +void common_chat_msg_parser::consume_literal(const std::string & literal) { + if (!try_consume_literal(literal)) { + throw common_chat_msg_partial_exception(literal); + } +} + +bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) { + auto handle_reasoning = [&](const std::string & reasoning, bool closed) { + auto stripped_reasoning = string_strip(reasoning); + if (stripped_reasoning.empty()) { + return; + } + if (syntax_.reasoning_in_content) { + add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "" : start_think); + add_content(stripped_reasoning); + if (closed) { + add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "" : end_think); + } + } else { + add_reasoning_content(stripped_reasoning); + } + }; + if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) { + if (syntax_.thinking_forced_open || try_consume_literal(start_think)) { + if (auto res = try_find_literal(end_think)) { + handle_reasoning(res->prelude, /* closed */ true); + consume_spaces(); + return true; + } + auto rest = consume_rest(); + if (!rest.empty()) { + handle_reasoning(rest, /* closed */ !is_partial()); + } + // Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877) + // if (!syntax_.thinking_forced_open) { + // throw common_chat_msg_partial_exception(end_think); + // } + return true; + } + } + return false; +} + +std::string common_chat_msg_parser::consume_rest() { + auto rest = input_.substr(pos_); + pos_ = input_.size(); + return rest; +} + +// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback. 
+std::optional common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) { + auto m = regex.search(input_, from == std::string::npos ? pos_ : from); + if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) { + return std::nullopt; + } + auto prelude = input_.substr(pos_, m.groups[0].begin - pos_); + pos_ = m.groups[0].end; + + if (add_prelude_to_content) { + add_content(prelude); + } + if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) { + if (is_partial()) { + throw common_chat_msg_partial_exception(regex.str()); + } + return std::nullopt; + } + return find_regex_result{prelude, m.groups}; +} + +common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) { + if (auto result = try_consume_regex(regex)) { + return *result; + } + throw common_chat_msg_partial_exception(regex.str()); +} + +std::optional common_chat_msg_parser::try_consume_regex(const common_regex & regex) { + auto m = regex.search(input_, pos_); + if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) { + return std::nullopt; + } + if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) { + if (is_partial()) { + throw common_chat_msg_partial_exception(regex.str()); + } + return std::nullopt; + } + if (m.groups[0].begin != pos_) { + // Didn't match at the current position. + return std::nullopt; + } + pos_ = m.groups[0].end; + + return find_regex_result { + /* .prelude = */ "", + m.groups, + }; +} + +std::optional common_chat_msg_parser::try_consume_json() { + auto it = input_.cbegin() + pos_; + const auto end = input_.cend(); + common_json result; + if (!common_json_parse(it, end, healing_marker_, result)) { + return std::nullopt; + } + pos_ = std::distance(input_.cbegin(), it); + if (result.healing_marker.marker.empty()) { + // No healing marker, just return the parsed json + return result; + } + if (!is_partial()) { + throw common_chat_msg_partial_exception("JSON"); + } + return result; +} + +common_json common_chat_msg_parser::consume_json() { + if (auto result = try_consume_json()) { + return *result; + } + throw common_chat_msg_partial_exception("JSON"); +} + +common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args( + const std::vector> & args_paths, + const std::vector> & content_paths +) { + if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) { + return *result; + } + throw common_chat_msg_partial_exception("JSON"); +} + +std::optional common_chat_msg_parser::try_consume_json_with_dumped_args( + const std::vector> & args_paths, + const std::vector> & content_paths +) { + auto partial = try_consume_json(); + if (!partial) { + return std::nullopt; + } + auto is_arguments_path = [&](const std::vector & path) { + return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end(); + }; + auto is_content_path = [&](const std::vector & path) { + return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end(); + }; + + if (partial->healing_marker.marker.empty()) { + if (args_paths.empty()) { + // No arguments to dump, and JSON was parsed fully. + return consume_json_result { + partial->json, + /* .is_partial = */ false, + }; + } + if (is_arguments_path({})) { + // Entire JSON is the arguments and was parsed fully. 
+ return consume_json_result { + partial->json.dump(), + /* .is_partial = */ false, + }; + } + } + + LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str()); + + auto found_healing_marker = false; + std::vector path; + std::function remove_unsupported_healings_and_dump_args = [&](const json & j) -> json { + if (is_arguments_path(path)) { + auto arguments = j.dump(); + if (is_partial() && !partial->healing_marker.marker.empty()) { + auto idx = arguments.find(partial->healing_marker.json_dump_marker); + if (idx != std::string::npos) { + arguments.resize(idx); + found_healing_marker = true; + } + if (arguments == "\"") { + // This happens because of completing `:"$magic` after `"arguments"` + arguments = ""; + } + } + return arguments; + } + if (is_content_path(path)) { + if (!j.is_string()) { + throw std::runtime_error("Content path must be a string"); + } + std::string str = j; + auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string + if (idx != std::string::npos) { + str.resize(idx); + found_healing_marker = true; + } + return str; + } + if (j.is_object()) { + auto obj = json::object(); + for (const auto & p : j.items()) { + const auto & key = p.key(); + const auto & value = p.value(); + const std::string key_str = key; // NOLINT + auto idx = key_str.find(healing_marker_); + if (idx != std::string::npos) { + found_healing_marker = true; + break; + } + path.push_back(key_str); + if (value.is_string()) { + const std::string value_str = value; + if (value_str.find(healing_marker_) != std::string::npos) { + found_healing_marker = true; + if (is_content_path(path)) { + if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) { + // The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair. + obj[key] = remove_unsupported_healings_and_dump_args(value); + } + } + break; + } + obj[key] = value; + } else { + obj[key] = remove_unsupported_healings_and_dump_args(value); + } + path.pop_back(); + } + return obj; + } + if (j.is_array()) { + auto arr = json::array(); + for (const auto & value : j) { + if (value.is_string()) { + std::string str = value; + auto idx = str.find(healing_marker_); + if (idx != std::string::npos) { + // Don't heal array values that aren't in the arguments. 
+ found_healing_marker = true; + break; + } + } + arr.push_back(remove_unsupported_healings_and_dump_args(value)); + } + return arr; + } + return j; + }; + + auto cleaned = remove_unsupported_healings_and_dump_args(partial->json); + LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str()); + return consume_json_result { + cleaned, + /* .is_partial = */ found_healing_marker, + }; +} diff --git a/common/chat-parser.h b/common/chat-parser.h new file mode 100644 index 000000000..7ee355056 --- /dev/null +++ b/common/chat-parser.h @@ -0,0 +1,118 @@ +#pragma once + +#include "chat.h" +#include "json-partial.h" +#include "regex-partial.h" + +#include + +#include +#include +#include + +class common_chat_msg_partial_exception : public std::runtime_error { + public: + common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {} +}; + +class common_chat_msg_parser { + std::string input_; + bool is_partial_; + common_chat_syntax syntax_; + std::string healing_marker_; + + size_t pos_ = 0; + common_chat_msg result_; + + public: + common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax); + const std::string & input() const { return input_; } + size_t pos() const { return pos_; } + const std::string & healing_marker() const { return healing_marker_; } + const bool & is_partial() const { return is_partial_; } + const common_chat_msg & result() const { return result_; } + const common_chat_syntax & syntax() const { return syntax_; } + + void move_to(size_t pos) { + if (pos > input_.size()) { + throw std::runtime_error("Invalid position!"); + } + pos_ = pos; + } + void move_back(size_t n) { + if (pos_ < n) { + throw std::runtime_error("Can't move back that far!"); + } + pos_ -= n; + } + + // Get the substring of the input at the given range + std::string str(const common_string_range & rng) const; + + // Appends to the result.content field + void add_content(const std::string & content); + + // Appends to the result.reasoning_content field + void add_reasoning_content(const std::string & reasoning_content); + + // Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything. + bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments); + + // Adds a tool call using the "name", "id" and "arguments" fields of the json object + bool add_tool_call(const nlohmann::ordered_json & tool_call); + + // Adds an array of tool calls using their "name", "id" and "arguments" fields. 
+ bool add_tool_calls(const nlohmann::ordered_json & arr); + + void finish(); + + bool consume_spaces(); + + void consume_literal(const std::string & literal); + + bool try_parse_reasoning(const std::string & start_think, const std::string & end_think); + + std::string consume_rest(); + + struct find_regex_result { + std::string prelude; + std::vector groups; + }; + + std::optional try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true); + + bool try_consume_literal(const std::string & literal); + + std::optional try_find_literal(const std::string & literal); + + find_regex_result consume_regex(const common_regex & regex); + + std::optional try_consume_regex(const common_regex & regex); + + std::optional try_consume_json(); + common_json consume_json(); + + struct consume_json_result { + nlohmann::ordered_json value; + bool is_partial; + }; + + /* + Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings. + + By default, object keys can't be truncated, nor can string values (their corresponding key is removed, + e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}` + + But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings + - with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}` + - with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}` + */ + consume_json_result consume_json_with_dumped_args( + const std::vector> & args_paths = {}, + const std::vector> & content_paths = {} + ); + std::optional try_consume_json_with_dumped_args( + const std::vector> & args_paths = {}, + const std::vector> & content_paths = {} + ); +}; diff --git a/common/chat.cpp b/common/chat.cpp index f138c7bca..f1ab4c85a 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1,10 +1,21 @@ #include "chat.h" +#include "chat-parser.h" +#include "common.h" +#include "json-partial.h" #include "json-schema-to-grammar.h" #include "log.h" -#include "minja/chat-template.hpp" -#include "minja/minja.hpp" +#include "regex-partial.h" +#include +#include + +#include +#include +#include #include +#include +#include +#include static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) { auto time = std::chrono::system_clock::to_time_t(now); @@ -15,6 +26,101 @@ static std::string format_time(const std::chrono::system_clock::time_point & now return res; } +static std::string string_diff(const std::string & last, const std::string & current) { + if (last.empty()) { + return current; + } + if (!string_starts_with(current, last)) { + if (string_starts_with(last, current)) { + // This happens if the last generation ended on a partial stop word (not erased), + // and the current ended on a stop word (erased). 
+ return ""; + } + throw std::runtime_error("Invalid diff: '" + last + "' not found at start of '" + current + "'"); + } + return current.substr(last.size()); +} + +static bool has_content_or_tool_calls(const common_chat_msg & msg) { + return !msg.content.empty() || !msg.tool_calls.empty(); +} + +template <> +json common_chat_msg::to_json_oaicompat() const +{ + json message { + {"role", "assistant"}, + }; + if (!reasoning_content.empty()) { + message["reasoning_content"] = reasoning_content; + } + if (content.empty() && !tool_calls.empty()) { + message["content"] = json(); + } else { + message["content"] = content; + } + if (!tool_calls.empty()) { + auto arr = json::array(); + for (const auto & tc : tool_calls) { + arr.push_back({ + {"type", "function"}, + {"function", { + {"name", tc.name}, + {"arguments", tc.arguments}, + }}, + {"id", tc.id}, + // // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo). + // // We only generate a random id for the ones that don't generate one by themselves + // // (they also won't get to see it as their template likely doesn't use it, so it's all for the client) + // {"id", tc.id.empty() ? gen_tool_call_id() : tc.id}, + }); + } + message["tool_calls"] = arr; + } + return message; +} + +std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) { + std::vector diffs; + // if (previous_msg.reasoning_content != current.reasoning_content) { + // auto & diff = diffs.emplace_back(); + // diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, current.reasoning_content); + // } + if (previous_msg.content != new_msg.content) { + auto & diff = diffs.emplace_back(); + diff.content_delta = string_diff(previous_msg.content, new_msg.content); + } + + if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) { + throw std::runtime_error("Invalid diff: now finding less tool calls!"); + } + + if (!previous_msg.tool_calls.empty()) { + auto idx = previous_msg.tool_calls.size() - 1; + const auto & pref = previous_msg.tool_calls[idx]; + const auto & newf = new_msg.tool_calls[idx]; + if (pref.name != newf.name) { + throw std::runtime_error("Invalid diff: tool call mismatch!"); + } + auto args_diff = string_diff(pref.arguments, newf.arguments); + if (!args_diff.empty() || pref.id != newf.id) { + auto & diff = diffs.emplace_back(); + diff.tool_call_index = idx; + if (pref.id != newf.id) { + diff.tool_call_delta.id = newf.id; + diff.tool_call_delta.name = newf.name; + } + diff.tool_call_delta.arguments = args_diff; + } + } + for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) { + auto & diff = diffs.emplace_back(); + diff.tool_call_index = idx; + diff.tool_call_delta = new_msg.tool_calls[idx]; + } + return diffs; +} + typedef minja::chat_template common_chat_template; struct common_chat_templates { @@ -32,7 +138,7 @@ struct templates_params { bool stream; std::string grammar; bool add_generation_prompt = true; - bool extract_reasoning = true; + bool enable_thinking = true; std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); }; @@ -277,6 +383,32 @@ json common_chat_tools_to_json_oaicompat(const std::vector & t return result; } +template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { + json delta = json::object(); + // if (!diff.reasoning_content_delta.empty()) { + // delta["reasoning_content"] = msg.reasoning_content; + // } + if 
(!diff.content_delta.empty()) { + delta["content"] = diff.content_delta; + } + if (diff.tool_call_index != std::string::npos) { + json tool_call; + tool_call["index"] = diff.tool_call_index; + if (!diff.tool_call_delta.id.empty()) { + tool_call["id"] = diff.tool_call_delta.id; + tool_call["type"] = "function"; + } + json function = json::object(); + if (!diff.tool_call_delta.name.empty()) { + function["name"] = diff.tool_call_delta.name; + } + function["arguments"] = diff.tool_call_delta.arguments; + tool_call["function"] = function; + delta["tool_calls"] = json::array({tool_call}); + } + return delta; +} + bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { @@ -444,7 +576,7 @@ common_chat_templates_ptr common_chat_templates_init( return tmpls; } -std::string common_chat_format_name(common_chat_format format) { +const char * common_chat_format_name(common_chat_format format) { switch (format) { case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only"; case COMMON_CHAT_FORMAT_GENERIC: return "Generic"; @@ -452,182 +584,127 @@ std::string common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x"; case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools"; case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1"; - case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: return "DeepSeek R1 (extract reasoning)"; case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; - case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)"; case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; - case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)"; default: throw std::runtime_error("Unknown chat format"); } } -static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { - // // https://json.nlohmann.me/features/parsing/sax_interface/ - struct json_error_locator : public nlohmann::json_sax { - std::size_t position; - bool found_error; +const char * common_reasoning_format_name(common_reasoning_format format) { + switch (format) { + case COMMON_REASONING_FORMAT_NONE: return "none"; + case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek"; + default: + throw std::runtime_error("Unknown reasoning format"); + } +} - json_error_locator() : position(0), found_error(false) {} - - bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT - this->position = position - 1; - this->found_error = true; - return false; +static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { + std::string arguments; + if (builder.is_partial()) { + arguments = (json {{"code", code + builder.healing_marker()}}).dump(); + auto idx = arguments.find(builder.healing_marker()); + if (idx != std::string::npos) { + arguments.resize(idx); } - bool null() override { return true; } // NOLINT - bool boolean(bool) override { return true; } // NOLINT - bool number_integer(number_integer_t) override { return true; } // NOLINT - bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT - bool number_float(number_float_t, const string_t &) override { 
return true; } // NOLINT - bool string(string_t &) override { return true; } // NOLINT - bool binary(binary_t &) override { return true; } // NOLINT - bool start_object(std::size_t) override { return true; } // NOLINT - bool key(string_t &) override { return true; } // NOLINT - bool end_object() override { return true; } - bool start_array(std::size_t) override { return true; } // NOLINT - bool end_array() override { return true; } - }; - json_error_locator err_loc; - json::sax_parse(it, end, &err_loc); - - std::string::const_iterator temptative_end; - if (err_loc.found_error) { - temptative_end = it + err_loc.position; } else { - temptative_end = end; - } - std::string json_sub {it, temptative_end}; - try { - out = json::parse(json_sub); - it = temptative_end; - return true; - } catch (const std::exception &) { - return false; - } -} - -static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) { - auto expected_it = expected.begin(); - auto tmp_it = it; - while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) { - ++tmp_it; - ++expected_it; - } - if (expected_it == expected.end()) { - it = tmp_it; - return true; - } - return false; -} - -static std::optional parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) { - std::smatch match; - if (std::regex_match(it, end, match, expected)) { - it = match.suffix().first; - return match; - } - return std::nullopt; -} - -static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) { - while (it != end && std::isspace(*it)) { - ++it; + arguments = (json {{"code", code}}).dump(); } + return arguments; } /** * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. * Aggregates the prefix, suffix and in-between text into the content. */ -static common_chat_msg parse_json_tool_calls( - const std::string& input, - const std::optional & trigger_opt, - const std::regex & function_regex, - const std::regex & close_regex, - bool allow_raw_python = false) { - std::smatch match; +static void parse_json_tool_calls( + common_chat_msg_parser & builder, + const std::optional & block_open, + const std::optional & function_regex_start_only, + const std::optional & function_regex, + const common_regex & close_regex, + const std::optional & block_close, + bool allow_raw_python = false, + const std::function & get_function_name = nullptr) { - common_chat_msg result; - result.role = "assistant"; + auto parse_tool_calls = [&]() { + size_t from = std::string::npos; + auto first = true; + while (true) { + auto res = function_regex_start_only && first + ? builder.try_consume_regex(*function_regex_start_only) + : function_regex + ? builder.try_find_regex(*function_regex, from) + : std::nullopt; + if (res) { + std::string name; + if (get_function_name) { + name = get_function_name(*res); + } else { + GGML_ASSERT(res->groups.size() == 2); + name = builder.str(res->groups[1]); + } + first = false; + if (name.empty()) { + // get_function_name signalled us that we should skip this match and treat it as content. 
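+ // (Annotation, not from the patch: the Functionary v3.2 parser below returns an empty name for the
+ // leading "all\n" channel, whose text is plain content; bumping `from` one past the match start
+ // prevents re-matching the same position on the next iteration.)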
+ from = res->groups[0].begin + 1; + continue; + } + from = std::string::npos; - - auto end = input.end(); - auto it = input.begin(); - - if (trigger_opt) { - if (!std::regex_search(it, end, match, *trigger_opt)) { - result.content = input; - return result; - } - result.content = match.prefix().str(); - it = match.suffix().first; - } - - while (it != end) { - std::sregex_iterator rend; - std::sregex_iterator rit(it, end, function_regex); - if (rit == rend) { - result.content += std::string(it, end); + auto maybe_raw_python = name == "python" && allow_raw_python; + if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { + if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { + if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(close_regex); + } + continue; + } + if (maybe_raw_python) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + if (!builder.add_tool_call(name, "", arguments)) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + return; + } + throw common_chat_msg_partial_exception("incomplete tool call"); + } break; } - auto name = rit->str(1); - result.content += std::string(it, rit->prefix().second); - it = rit->suffix().first; - - json arguments; - if (parse_json(it, end, arguments)) { - if (!std::regex_search(it, end, match, close_regex)) { - throw std::runtime_error("Malformed input, missing closing pattern: " + input); - } - it = match.suffix().first; - result.tool_calls.push_back({name, arguments.is_string() ? arguments.get() : arguments.dump(), /* id= */ ""}); - } else { - if (allow_raw_python && name == "python") { - result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""}); - break; - } - throw std::runtime_error("Failed to parse json tool call arguments: " + input); + if (block_close) { + builder.consume_regex(*block_close); } - } - - if (!result.tool_calls.empty()) { - if (!string_strip(result.content).empty()) { - LOG_WRN("Content found with tool calls: %s\n", result.content.c_str()); - } - result.content = ""; - } - return result; -} - -static common_chat_tool_call process_tool_call(const json & tool_call) { - const auto & arguments = tool_call.at("arguments"); - return { - /* .name = */ tool_call.at("name"), - /* .arguments = */ arguments.is_string() ? arguments.get() : arguments.dump(), - /* .id = */ tool_call.contains("id") ? 
tool_call.at("id") : "", + builder.consume_spaces(); + builder.add_content(builder.consume_rest()); }; -} -static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) { - auto content_end = input.find(prefix); - size_t tc_start = std::string::npos; - - common_chat_msg result; - result.role = "assistant"; - if (content_end == std::string::npos) { - result.content = input; - } else { - tc_start = content_end + prefix.size() - rstrip_prefix; - result.content = input.substr(0, content_end); - auto tool_calls = json::parse(input.substr(tc_start)); - for (const auto & tool_call : tool_calls) { - result.tool_calls.emplace_back(process_tool_call(tool_call)); + if (block_open) { + if (auto res = builder.try_find_regex(*block_open)) { + parse_tool_calls(); + } else { + builder.add_content(builder.consume_rest()); } + } else { + parse_tool_calls(); + } +} + +static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) { + static const std::vector> args_paths = {{"arguments"}}; + if (auto res = builder.try_find_regex(prefix)) { + builder.move_back(rstrip_prefix); + auto tool_calls = builder.consume_json_with_dumped_args(args_paths); + if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call array"); + } + } else { + builder.add_content(builder.consume_rest()); } - return result; } static void foreach_function(const json & tools, const std::function & fn) { @@ -754,29 +831,36 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp data.format = COMMON_CHAT_FORMAT_GENERIC; return data; } -static common_chat_msg common_chat_parse_generic(const std::string & input) { - json data = json::parse(input); - common_chat_msg result; - result.role = "assistant"; - if (data.contains("tool_calls")) { - for (const auto & tool_call : data.at("tool_calls")) { - result.tool_calls.push_back({ - tool_call.at("name"), - tool_call.at("arguments").dump(), - tool_call.contains("id") ? tool_call.at("id") : "", - }); - } - } else if (data.contains("tool_call")) { - result.tool_calls.push_back({ - data.at("tool_call").at("name"), - data.at("tool_call").at("arguments").dump(), - /* id= */ "", - }); - } else if (data.contains("response")) { - const auto & response = data.at("response"); - result.content = response.is_string() ? response.get() : response.dump(2); +static void common_chat_parse_generic(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const std::vector> content_paths = { + {"response"}, + }; + static const std::vector> args_paths = { + {"tool_call", "arguments"}, + {"tool_calls", "arguments"}, + }; + auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); + if (data.value.contains("tool_calls")) { + if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool calls"); + } + } else if (data.value.contains("tool_call")) { + if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } else if (data.value.contains("response")) { + const auto & response = data.value.at("response"); + builder.add_content(response.is_string() ? 
response.template get() : response.dump(2)); + if (data.is_partial) { + throw common_chat_msg_partial_exception("incomplete response"); + } + } else { + throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); } - return result; } static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) { @@ -823,12 +907,44 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO; return data; } -static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) { - return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); +static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); } static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; + + auto adjusted_messages = json::array(); + for (const auto & msg : inputs.messages) { + auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string(); + auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); + if (has_reasoning_content && has_tool_calls) { + auto adjusted_message = msg; + adjusted_message["tool_plan"] = msg.at("reasoning_content"); + adjusted_message.erase("reasoning_content"); + adjusted_messages.push_back(adjusted_message); + } else { + adjusted_messages.push_back(msg); + } + } + data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {}); + data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; + if (string_ends_with(data.prompt, "<|START_THINKING|>")) { + if (!inputs.enable_thinking) { + data.prompt += "<|END_THINKING|>"; + } else { + data.thinking_forced_open = true; + } + } else if (!inputs.enable_thinking && string_ends_with(data.prompt, "<|CHATBOT_TOKEN|>")) { + data.prompt += "<|START_THINKING|><|END_THINKING|>"; + } + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); @@ -859,11 +975,16 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ if (!inputs.parallel_tool_calls) { schema["maxItems"] = 1; } - builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\""); + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"<|END_THINKING|>\" space )? " : "") + + "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\""); }); data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_WORD, - "<|START_ACTION|>", + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? 
"[\\s\\S]*?(<\\|END_THINKING\\|>\\s*)" : "(?:<\\|START_THINKING\\|>[\\s\\S]*?<\\|END_THINKING\\|>\\s*)?") + + "(<\\|START_ACTION\\|>)[\\s\\S]*" }); data.preserved_tokens = { "<|START_ACTION|>", @@ -873,61 +994,40 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ "<|START_THINKING|>", "<|END_THINKING|>", }; - auto adjusted_messages = json::array(); - for (const auto & msg : inputs.messages) { - auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string(); - auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); - if (has_reasoning_content && has_tool_calls) { - auto adjusted_message = msg; - adjusted_message["tool_plan"] = msg.at("reasoning_content"); - adjusted_message.erase("reasoning_content"); - adjusted_messages.push_back(adjusted_message); - } else { - adjusted_messages.push_back(msg); - } - } - data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {}); - data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING : COMMON_CHAT_FORMAT_COMMAND_R7B; return data; } -static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) { - static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)"); - static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>"); - static const std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>"); - std::smatch match; +static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); - common_chat_msg result; - result.role = "assistant"; + static const common_regex start_action_regex("<\\|START_ACTION\\|>"); + static const common_regex end_action_regex("<\\|END_ACTION\\|>"); + static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); + static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); - std::string rest = input; - - if (std::regex_match(rest, match, thought_regex)) { - if (extract_reasoning) { - result.reasoning_content = match[2].str(); - } else if (!match[2].str().empty()) { - // Let the unparsed thinking tags through in content only if their insides aren't empty. - result.content = match[1].str(); + if (auto res = builder.try_find_regex(start_action_regex)) { + // If we didn't extract thoughts, prelude includes them. + auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); + for (const auto & tool_call : tool_calls.value) { + std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; + std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; + std::string arguments = tool_call.contains("parameters") ? 
tool_call.at("parameters") : ""; + if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } } - rest = match[3].str(); - } - if (std::regex_match(rest, match, action_regex)) { - auto actions_str = match[1].str(); - auto actions = json::parse(actions_str); - for (const auto & action : actions) { - result.tool_calls.push_back({ - /* .name = */ action.at("tool_name"), - /* .arguments = */ action.at("parameters").dump(), - /* .id = */ action.at("tool_call_id"), - }); + if (tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(end_action_regex); + } else if (auto res = builder.try_find_regex(start_response_regex)) { + if (!builder.try_find_regex(end_response_regex)) { + builder.add_content(builder.consume_rest()); + throw common_chat_msg_partial_exception(end_response_regex.str()); } - } else if (std::regex_match(rest, match, response_regex)) { - auto response = match[1].str(); - result.content += response; } else { - result.content += rest; + builder.add_content(builder.consume_rest()); } - return result; } static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { @@ -1004,8 +1104,8 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te }); // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name. data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, - "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*", + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + "(\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\")[\\s\\S]*", // + name + "\"[\\s\\S]*", }); if (!builtin_tools.empty()) { data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); @@ -1028,78 +1128,64 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te }); return data; } -static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) { - // TODO: tighten & simplify the parser, don't accept leading text context. 
- static const std::regex function_regex( +static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex function_regex( "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); - static const std::regex close_regex("\\}\\s*"); - static const std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)"); + static const common_regex close_regex("\\}\\s*"); + + static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); + static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); if (with_builtin_tools) { - std::smatch match; - if (std::regex_match(input, match, builtin_call_regex)) { - try { - auto name = match[1].str(); - auto arg_name = match[2].str(); - auto arg_value_str = match[3].str(); - auto arg_value = json::parse(arg_value_str); + static const common_regex builtin_call_regex("<\\|python_tag\\|>"); + if (auto res = builder.try_find_regex(builtin_call_regex)) { + auto fun_res = builder.consume_regex(function_name_regex); + auto function_name = builder.str(fun_res.groups[1]); - common_chat_msg msg; - msg.role = "assistant"; - msg.tool_calls.push_back({ - /* .name = */ name, - /* .arguments = */ (json { - {arg_name, arg_value}, - }).dump(), - /* .id = */ "", - }); - return msg; - } catch (const std::exception & e) { - LOG_WRN("Failed to parse builtin tool call arguments (%s): %s", e.what(), input.c_str()); + common_healing_marker healing_marker; + json args = json::object(); + while (true) { + if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { + auto arg_name = builder.str(arg_res->groups[1]); + auto partial = builder.consume_json(); + args[arg_name] = partial.json; + healing_marker.marker = partial.healing_marker.marker; + healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; + builder.consume_spaces(); + if (!builder.try_consume_literal(",")) { + break; + } + } else { + break; + } } + builder.consume_literal(")"); + builder.consume_spaces(); + + auto arguments = args.dump(); + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + return; } } - return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ function_regex, + /* function_regex= */ std::nullopt, + close_regex, + std::nullopt); + } static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_rule(name + "-call", - "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n" - "```json\\n\" " + builder.add_schema(name + "-args", parameters) + " " 
- "\"```<|tool▁call▁end|>\"")); - }); - // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, - // so we accept common variants (then it's all constrained) - builder.add_rule("root", - "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) " - "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " - "\"<|tool▁calls▁end|>\"" - " space"); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool▁calls▁begin|>"}); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls_begin|>"}); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool calls begin|>"}); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool\\_calls\\_begin|>"}); - data.preserved_tokens = { - "", - "", - "<|tool▁calls▁begin|>", - "<|tool▁call▁begin|>", - "<|tool▁sep|>", - "<|tool▁call▁end|>", - "<|tool▁calls▁end|", - }; - }); - } auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); // Hacks to fix the official (broken) prompt. @@ -1120,45 +1206,76 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ "$1<|tool▁calls▁end|><|end▁of▁sentence|>$2"); } data.prompt = prompt; - data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1; + data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1; + if (string_ends_with(data.prompt, "\n")) { + if (!inputs.enable_thinking) { + data.prompt += ""; + } else { + data.thinking_forced_open = true; + } + } + + if (inputs.tools.is_array() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_rule(name + "-call", + "( \"<|tool▁call▁begin|>\" )? \"function<|tool▁sep|>" + name + "\\n" + "```json\\n\" " + builder.add_schema(name + "-args", parameters) + " " + "\"```<|tool▁call▁end|>\"")); + }); + // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, + // so we accept common variants (then it's all constrained) + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + + "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" | \"<|tool▁calls|>\" ) " + "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " + "\"<|tool▁calls▁end|>\"" + " space"); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? 
"[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + + "(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*" + }); + data.preserved_tokens = { + "", + "", + "<|tool▁calls▁begin|>", + "<|tool▁call▁begin|>", + "<|tool▁sep|>", + "<|tool▁call▁end|>", + "<|tool▁calls▁end|", + }; + }); + } return data; } -static common_chat_msg handle_think_tag_prelude(const std::string & input, bool extract_reasoning, const std::function & rest_parser) { - std::smatch match; - static const std::regex reasoning_content_regex("((?:)?([\\s\\S\\r\\n]*?))?([\\s\\S\\r\\n]*)"); - if (std::regex_match(input, match, reasoning_content_regex)) { - auto rest = match[3].str(); - auto msg = rest_parser(rest); - auto reasoning_content = string_strip(match[2].str()); - if (extract_reasoning) { - msg.reasoning_content = reasoning_content; - } else if (!reasoning_content.empty()) { - std::ostringstream content; - content << "" << reasoning_content << "" << msg.content; - msg.content = content.str(); - } - return msg; +static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; } - return rest_parser(input); -} -static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) { - return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) { - static const std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n"); - static const std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); - static const std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>"); - common_chat_msg msg; - msg.role = "assistant"; - std::smatch match; - if (std::regex_search(input, match, tool_calls_regex)) { - auto tool_calls = match[1].str(); - auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex); - msg.tool_calls = std::move(msg2.tool_calls); - } else { - msg.content = input; - } - return msg; - }); + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); + static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); } static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { @@ -1206,13 +1323,19 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c } return data; } -static common_chat_msg common_chat_parse_firefunction_v2(const std::string & input) { - return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); +static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const common_regex prefix(regex_escape(" functools[")); + 
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); } static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) { // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar + // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code. common_chat_params data; data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; @@ -1226,24 +1349,21 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ std::string name = function.at("name"); auto parameters = function.at("parameters"); builder.resolve_refs(parameters); + std::string args_pattern = "[\\s\\S]*"; auto args_rule = builder.add_schema(name + "-args", parameters); - first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule)); - subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); + if (name == "python") { + args_rule = builder.add_rule(name + "-maybe-raw-args", args_rule + " | [^{] .*"); + } else { + args_pattern = "\\{" + args_pattern; + } + auto call_rule = builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule); + first_tool_rules.push_back(call_rule); + if (inputs.parallel_tool_calls) { + subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>\" " + call_rule)); + } data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, - regex_escape(name + "\n"), - }); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, - regex_escape("assistant<|end_header_id|>\n" + name + "\n"), - }); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_WORD, - regex_escape(">>>" + name + "\n"), - }); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_WORD, - ">>>assistant<|end_header_id|>\n" + name, + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + "((?:[\\s\\S]+?>>>)?" + regex_escape(name) + "\n)" + args_pattern, }); }); data.preserved_tokens = { @@ -1261,40 +1381,33 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ } return data; } +static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { + static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); + static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); + static const common_regex close_regex(R"(\s*)"); -static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) { - static const std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)"); - static const std::regex close_regex(R"($|(?=>>>))"); - - std::string content; - auto it = input.begin(); - const auto end = input.end(); - - if (parse_literal(it, end, "all\n")) { - std::smatch match; - if (std::regex_search(it, end, match, function_regex)) { - auto fun_it = match.prefix().second; - content = std::string(it, fun_it); - it = fun_it; - } else { - common_chat_msg res; - res.role = "assistant"; - res.content = std::string(it, end); - return res; - } - } - // TODO: tighten & simplify. 
- try { - auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true); - res.content = content + res.content; - return res; - } catch (const std::exception & e) { - LOG_ERR("Failed to parse functionary v3.2 input: %s\n", e.what()); - common_chat_msg res; - res.role = "assistant"; - res.content = input; - return res; - } + parse_json_tool_calls( + builder, + std::nullopt, + function_regex_start_only, + function_regex, + close_regex, + std::nullopt, + /* allow_raw_python= */ true, + /* get_function_name= */ [&](const auto & res) -> std::string { + auto at_start = res.groups[0].begin == 0; + auto name = builder.str(res.groups[1]); + if (!name.empty() && name.back() == '{') { + // Unconsume the opening brace '{' to ensure the JSON parsing goes well. + builder.move_back(1); + } + auto idx = name.find_last_not_of("\n{"); + name = name.substr(0, idx + 1); + if (at_start && name == "all") { + return ""; + } + return name; + }); } static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { @@ -1355,229 +1468,224 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con // TODO: if (has_raw_python) return data; } -static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) { - // This version of Functionary still supports the llama 3.1 tool call format for the python tool. - static const std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); - std::smatch match; - if (std::regex_search(input, match, python_tag_regex)) { - auto code = match[1].str(); - common_chat_msg msg; - msg.role = "assistant"; - msg.content = match.prefix().str(); - msg.tool_calls.push_back({ - /* .name = */ "python", - /* .arguments = */ (json {{"code", code}}).dump(), - /* .id = */ "", - }); - return msg; +static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + // This version of Functionary still supports the llama 3.1 tool call format for the python tool. + static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); + + static const common_regex function_regex(R"()"); + static const common_regex close_regex(R"()"); + + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + std::nullopt); + + if (auto res = builder.try_find_regex(python_tag_regex)) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + builder.add_tool_call("python", "", arguments); + return; } - static const std::regex function_regex(R"()"); - static const std::regex close_regex(R"()"); - // TODO: tighten & simplify. 
- return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); } static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - // (content)?({"name": "foo", "arguments": {"a": 1}})* - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - std::vector tool_call_alts; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_schema(name + "-call", { - {"type", "object"}, - {"properties", json { - {"name", json {{"const", name}}}, - {"arguments", parameters}, - }}, - {"required", json::array({"name", "arguments"})}, - })); - tool_call_alts.push_back(builder.add_rule( - name + "-function-tag", - "\"\" space " + - builder.add_schema(name + "-args", parameters) + " " - "\"\" space")); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_WORD, - "", - }); - auto escaped_name = regex_escape(name); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - " alt_tags { - any_tool_call, - "\"\" space " + any_tool_call + " \"\"", - // The rest is just to accommodate common "good bad" outputs. - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - }; - auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space"); - tool_call_alts.push_back(wrappable_tool_call); - tool_call_alts.push_back( - "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space "); - auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | ")); - builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, ""}); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "|||)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"", - }); - data.preserved_tokens = { - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "```", - "```json", - "```xml", - }; - }); + json additional_context = { + {"enable_thinking", inputs.enable_thinking}, + }; + + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? 
json() : inputs.tools, inputs.add_generation_prompt, additional_context); + data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; + if (string_ends_with(data.prompt, "\n")) { + if (!inputs.enable_thinking) { + data.prompt += ""; + } else { + data.thinking_forced_open = true; + } + } + + if (!inputs.tools.is_null()) { + // (content)?({"name": "foo", "arguments": {"a": 1}})* + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + std::vector tool_call_alts; + std::vector escaped_names; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_schema(name + "-call", { + {"type", "object"}, + {"properties", json { + {"name", json {{"const", name}}}, + {"arguments", parameters}, + }}, + {"required", json::array({"name", "arguments"})}, + })); + tool_call_alts.push_back(builder.add_rule( + name + "-function-tag", + "\"\" space " + + builder.add_schema(name + "-args", parameters) + " " + "\"\" space")); + + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_WORD, + "", + }); + auto escaped_name = regex_escape(name); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, + " alt_tags { + any_tool_call, + "\"\" space " + any_tool_call + " \"\"", + // The rest is just to accommodate common "good bad" outputs. + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + }; + auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space"); + tool_call_alts.push_back(wrappable_tool_call); + tool_call_alts.push_back( + "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space "); + auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | ")); + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + + (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call)); + // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives) + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + ( + "(\\s*" + "(?:" + "||||)?" + "\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\"" + ")" + ")[\\s\\S]*" + ), + }); + data.preserved_tokens = { + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "```", + "```json", + "```xml", + }; + }); + } - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); - data.format = inputs.extract_reasoning ? 
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING : COMMON_CHAT_FORMAT_HERMES_2_PRO; return data; } -static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input, bool extract_reasoning) { - return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) { - static const std::regex open_regex( - "(?:" - "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) - "(" // match 2 (open_tag) - "|" - "|" - "|" - "|" - "|" - "|" - "|" +static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex open_regex( + "(?:" + "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) + "(" // match 2 (open_tag) + "" + "|" + "|" + "|" + "|" + "|" + "|" + "|" ")?" - "(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest) - ")" - "|" - "(?:]+)>" // match 4 (function name) - "|)" // match 5 (function name again) - "([\\s\\S]*)" // match 6 (function arguments + rest)})" - ); + "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) + ")" + "|]+)>" // match 4 (function name) + "|" // match 5 (function name again) + ); - try { - common_chat_msg msg; - msg.role = "assistant"; + if (auto res = builder.try_find_regex(open_regex)) { + const auto & block_start = res->groups[1]; + std::string block_end = block_start.empty() ? "" : "```"; - std::string::const_iterator it = input.begin(); - const std::string::const_iterator end = input.end(); - std::smatch match; + const auto & open_tag = res->groups[2]; + std::string close_tag; - while (it != end) { - if (std::regex_search(it, end, match, open_regex)) { - // Add content before the match - msg.content += std::string(it, match[0].first); + if (!res->groups[3].empty()) { + builder.move_to(res->groups[3].begin); + close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + builder.add_content(builder.consume_rest()); + } else { + throw common_chat_msg_partial_exception("failed to parse tool call"); + } + } else { + auto function_name = builder.str(res->groups[4]); + if (function_name.empty()) { + function_name = builder.str(res->groups[5]); + } + GGML_ASSERT(!function_name.empty()); - auto open_tag = match[2].str(); - std::string close_tag; + close_tag = ""; - if (match[3].matched) { - close_tag = open_tag.empty() ? 
"" : ""; - // Start parsing from after the opening tags - auto json_it = match[6].first; - json arguments; - if (parse_json(json_it, end, arguments)) { - msg.tool_calls.emplace_back(process_tool_call({ - {"name", function_name}, - {"arguments", arguments}, - })); - it = json_it; // Move iterator past parsed JSON - - // Handle close tags - consume_spaces(it, end); - if (!close_tag.empty() && !parse_literal(it, end, close_tag)) { - throw std::runtime_error("Failed to parse closing tag"); - } - consume_spaces(it, end); - if (!block_end.empty() && !parse_literal(it, end, block_end)) { - throw std::runtime_error("Failed to parse block end"); - } - consume_spaces(it, end); - } else { - // Not a valid tool call, treat as content - msg.content += std::string(match[0].first, match[0].second); - it = match[0].second; - } - } - } else { - // Add remaining content - msg.content += std::string(it, end); - break; + if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { + if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); } } - return msg; - } catch (const std::exception & e) { - LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what()); - common_chat_msg msg; - msg.role = "assistant"; - msg.content = input; - return msg; + builder.add_content(builder.consume_rest()); } - }); + } else { + builder.add_content(builder.consume_rest()); + } } static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { @@ -1609,8 +1717,8 @@ static common_chat_params common_chat_templates_apply_jinja( const auto & caps = tmpl.original_caps(); params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content); params.add_generation_prompt = inputs.add_generation_prompt; - params.extract_reasoning = inputs.extract_reasoning; params.tool_choice = inputs.tool_choice; + params.enable_thinking = inputs.enable_thinking; params.grammar = inputs.grammar; params.now = inputs.now; if (!inputs.json_schema.empty()) { @@ -1644,7 +1752,7 @@ static common_chat_params common_chat_templates_apply_jinja( } // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) - if (src.find("") != std::string::npos && params.json_schema.is_null() && params.tools.is_array() && params.json_schema.is_null()) { + if (src.find("") != std::string::npos && params.json_schema.is_null()) { return common_chat_params_init_hermes_2_pro(tmpl, params); } @@ -1758,44 +1866,64 @@ common_chat_params common_chat_templates_apply( : common_chat_templates_apply_legacy(tmpls, inputs); } -static common_chat_msg common_chat_parse_content_only(const std::string & input) { - common_chat_msg msg; - msg.role = "assistant"; - msg.content = input; - return msg; +static void common_chat_parse_content_only(common_chat_msg_parser & builder) { + builder.add_content(builder.consume_rest()); } -common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) { - switch (format) { +static void common_chat_parse(common_chat_msg_parser & builder) { + LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); + + switch (builder.syntax().format) { case 
COMMON_CHAT_FORMAT_CONTENT_ONLY: - return common_chat_parse_content_only(input); + common_chat_parse_content_only(builder); + break; case COMMON_CHAT_FORMAT_GENERIC: - return common_chat_parse_generic(input); + common_chat_parse_generic(builder); + break; case COMMON_CHAT_FORMAT_MISTRAL_NEMO: - return common_chat_parse_mistral_nemo(input); + common_chat_parse_mistral_nemo(builder); + break; case COMMON_CHAT_FORMAT_LLAMA_3_X: - return common_chat_parse_llama_3_1(input); + common_chat_parse_llama_3_1(builder); + break; case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: - return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true); + common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); + break; case COMMON_CHAT_FORMAT_DEEPSEEK_R1: - return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ false); - case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: - return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ true); + common_chat_parse_deepseek_r1(builder); + break; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: - return common_chat_parse_functionary_v3_2(input); + common_chat_parse_functionary_v3_2(builder); + break; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: - return common_chat_parse_functionary_v3_1_llama_3_1(input); + common_chat_parse_functionary_v3_1_llama_3_1(builder); + break; case COMMON_CHAT_FORMAT_HERMES_2_PRO: - return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ false); - case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: - return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ true); + common_chat_parse_hermes_2_pro(builder); + break; case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: - return common_chat_parse_firefunction_v2(input); + common_chat_parse_firefunction_v2(builder); + break; case COMMON_CHAT_FORMAT_COMMAND_R7B: - return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false); - case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: - return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true); + common_chat_parse_command_r7b(builder); + break; default: - throw std::runtime_error("Unsupported format: " + common_chat_format_name(format)); + throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } + builder.finish(); +} + +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { + common_chat_msg_parser builder(input, is_partial, syntax); + try { + common_chat_parse(builder); + } catch (const common_chat_msg_partial_exception & ex) { + LOG_DBG("Partial parse: %s\n", ex.what()); + if (!is_partial) { + throw std::runtime_error(ex.what()); + } + } + auto msg = builder.result(); + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); + return msg; } diff --git a/common/chat.h b/common/chat.h index d26a09c2f..f6b1d0ffc 100644 --- a/common/chat.h +++ b/common/chat.h @@ -3,6 +3,7 @@ #pragma once #include "common.h" +#include #include #include #include @@ -13,11 +14,19 @@ struct common_chat_tool_call { std::string name; std::string arguments; std::string id; + + bool operator==(const common_chat_tool_call & other) const { + return name == other.name && arguments == other.arguments && id == other.id; + } }; struct common_chat_msg_content_part { std::string type; std::string text; + + bool operator==(const common_chat_msg_content_part & other) const { + return type == other.type && text == other.text; + 
} }; struct common_chat_msg { @@ -28,6 +37,51 @@ struct common_chat_msg { std::string reasoning_content; std::string tool_name; std::string tool_call_id; + + template T to_json_oaicompat() const; + + bool empty() const { + return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); + } + void ensure_tool_call_ids_set(std::vector & ids_cache, const std::function & gen_tool_call_id) { + for (auto i = 0u; i < tool_calls.size(); i++) { + if (ids_cache.size() <= i) { + auto id = tool_calls[i].id; + if (id.empty()) { + id = gen_tool_call_id(); + } + ids_cache.push_back(id); + } + tool_calls[i].id = ids_cache[i]; + } + } + bool operator==(const common_chat_msg & other) const { + return role == other.role + && content == other.content + && content_parts == other.content_parts + && tool_calls == other.tool_calls + && reasoning_content == other.reasoning_content + && tool_name == other.tool_name + && tool_call_id == other.tool_call_id; + } + bool operator!=(const common_chat_msg & other) const { + return !(*this == other); + } +}; + +struct common_chat_msg_diff { + // std::string reasoning_content_delta; + std::string content_delta; + size_t tool_call_index = std::string::npos; + common_chat_tool_call tool_call_delta; + + static std::vector compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg); + + bool operator==(const common_chat_msg_diff & other) const { + return content_delta == other.content_delta + && tool_call_index == other.tool_call_index + && tool_call_delta == other.tool_call_delta; + } }; struct common_chat_tool { @@ -49,14 +103,11 @@ enum common_chat_format { COMMON_CHAT_FORMAT_LLAMA_3_X, COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, COMMON_CHAT_FORMAT_DEEPSEEK_R1, - COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, COMMON_CHAT_FORMAT_FIREFUNCTION_V2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, COMMON_CHAT_FORMAT_HERMES_2_PRO, - COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING, COMMON_CHAT_FORMAT_COMMAND_R7B, - COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; @@ -71,7 +122,8 @@ struct common_chat_templates_inputs { std::vector tools; common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; bool parallel_tool_calls = false; - bool extract_reasoning = true; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; + bool enable_thinking = true; std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); }; @@ -80,11 +132,21 @@ struct common_chat_params { std::string prompt; std::string grammar; bool grammar_lazy = false; + bool thinking_forced_open = false; std::vector grammar_triggers; std::vector preserved_tokens; std::vector additional_stops; }; +struct common_chat_syntax { + common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; + // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) + bool reasoning_in_content = false; + bool thinking_forced_open = false; + bool parse_tool_calls = true; +}; + // Check if the template supplied via "--chat-template" is supported or not. 
Returns true if it's valid bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); @@ -121,8 +183,9 @@ std::string common_chat_format_example( const struct common_chat_templates * tmpls, bool use_jinja); -std::string common_chat_format_name(common_chat_format format); -common_chat_msg common_chat_parse( const std::string & input, common_chat_format format); +const char* common_chat_format_name(common_chat_format format); +const char* common_reasoning_format_name(common_reasoning_format format); +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax); common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); @@ -135,3 +198,5 @@ template T common_chat_msgs_to_json_oaicompat(const std::vector std::vector common_chat_tools_parse_oaicompat(const T & tools); template T common_chat_tools_to_json_oaicompat(const std::vector & tools); + +template T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff); diff --git a/common/common.cpp b/common/common.cpp index eb16055ea..4cc40ed8b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -203,6 +203,7 @@ bool set_process_priority(enum ggml_sched_priority prio) { DWORD p = NORMAL_PRIORITY_CLASS; switch (prio) { + case GGML_SCHED_PRIO_LOW: p = BELOW_NORMAL_PRIORITY_CLASS; break; case GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break; case GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break; case GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break; @@ -228,6 +229,7 @@ bool set_process_priority(enum ggml_sched_priority prio) { int p = 0; switch (prio) { + case GGML_SCHED_PRIO_LOW: p = 5; break; case GGML_SCHED_PRIO_NORMAL: p = 0; break; case GGML_SCHED_PRIO_MEDIUM: p = -5; break; case GGML_SCHED_PRIO_HIGH: p = -10; break; @@ -849,7 +851,7 @@ std::string fs_get_cache_directory() { if (getenv("LLAMA_CACHE")) { cache_directory = std::getenv("LLAMA_CACHE"); } else { -#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) +#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__) if (std::getenv("XDG_CACHE_HOME")) { cache_directory = std::getenv("XDG_CACHE_HOME"); } else { @@ -903,13 +905,16 @@ struct common_init_result common_init_from_params(common_params & params) { ok = false; } - if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__); - ok = false; - } + bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; + bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL; - if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__); + if (!has_eos && !has_sep) { + LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__); + ok = false; + } else if (!has_eos) { + LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__); + } else if (!has_sep) { + LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__); ok = false; } diff --git a/common/common.h b/common/common.h index 556ff5be4..cee1e3039 100644 --- a/common/common.h +++ b/common/common.h @@ -76,7 +76,7 @@ enum llama_example { LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, - LLAMA_EXAMPLE_LLAVA, + LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_PARALLEL, LLAMA_EXAMPLE_TTS, 
@@ -115,7 +115,7 @@ enum common_grammar_trigger_type { COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN, COMMON_GRAMMAR_TRIGGER_TYPE_WORD, COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, }; struct common_grammar_trigger { @@ -291,6 +291,7 @@ struct common_params { int32_t verbosity = 0; int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector + bool offline = false; int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line @@ -368,6 +369,7 @@ struct common_params { bool use_jinja = false; // NOLINT bool enable_chat_template = true; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + int reasoning_budget = -1; bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response std::vector api_keys; diff --git a/common/json-partial.cpp b/common/json-partial.cpp new file mode 100644 index 000000000..d9d916998 --- /dev/null +++ b/common/json-partial.cpp @@ -0,0 +1,256 @@ +#include "json-partial.h" + +#include "log.h" + +#include + +#include + +using json = nlohmann::ordered_json; + +enum common_json_stack_element_type { + COMMON_JSON_STACK_ELEMENT_OBJECT, + COMMON_JSON_STACK_ELEMENT_KEY, + COMMON_JSON_STACK_ELEMENT_ARRAY, +}; + +struct common_json_stack_element { + common_json_stack_element_type type; + std::string key; +}; + +bool common_json_parse( + const std::string & input, + const std::string & healing_marker, + common_json & out) +{ + std::string::const_iterator it = input.begin(); + const auto end = input.end(); + return common_json_parse(it, end, healing_marker, out); +} + +bool common_json_parse( + std::string::const_iterator & it, + const std::string::const_iterator & end, + const std::string & healing_marker, + common_json & out) +{ + // // https://json.nlohmann.me/features/parsing/sax_interface/ + struct json_error_locator : public nlohmann::json_sax { + std::size_t position; + bool found_error; + std::string last_token; + std::string exception_message; + std::vector stack; + + json_error_locator() : position(0), found_error(false) {} + + bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT + this->position = position - 1; + this->found_error = true; + this->last_token = last_token; + this->exception_message = ex.what(); + return false; + } + void close_value() { + if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) { + stack.pop_back(); + } + } + bool null() override { // NOLINT + close_value(); + return true; + } + bool boolean(bool) override { // NOLINT + close_value(); + return true; + } + bool number_integer(number_integer_t) override { // NOLINT + close_value(); + return true; + } + bool number_unsigned(number_unsigned_t) override { // NOLINT + close_value(); + return true; + } + bool number_float(number_float_t, const string_t &) override { // NOLINT + close_value(); + return true; + } + bool string(string_t &) override { // NOLINT + close_value(); + return true; + } + bool binary(binary_t &) override { // NOLINT + close_value(); + return true; + } + bool start_object(std::size_t) override { // NOLINT + stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""}); + return true; + } + bool 
end_object() override { + GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT); + stack.pop_back(); + close_value(); + return true; + } + bool key(string_t & key) override { // NOLINT + stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key}); + return true; + } + bool start_array(std::size_t) override { // NOLINT + stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""}); + return true; + } + bool end_array() override { + GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY); + stack.pop_back(); + close_value(); + return true; + } + }; + json_error_locator err_loc; + auto start = it; + json::sax_parse(it, end, &err_loc); + + if (err_loc.found_error) { + it = start; + auto temptative_end = it + err_loc.position; + // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str()); + + auto input = std::string(it, temptative_end); + try { + out.json = json::parse(input); + // out.json = json::parse(it, temptative_end); + it = temptative_end; + return true; + } catch (const std::exception & ex) { + // No, needs healing. + LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str()); + } + auto can_parse = [](const std::string & str) { + try { + auto _ = json::parse(str); // NOLINT + return true; + } catch (const std::exception &) { + return false; + } + }; + if (!healing_marker.empty() && !err_loc.stack.empty()) { + std::string str(it, temptative_end); + auto last_non_sp_pos = str.find_last_not_of(" \n\r\t"); + if (last_non_sp_pos == std::string::npos) { + throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location"); + } + auto last_non_sp_char = str[last_non_sp_pos]; + // Used to detect stops on a number, which may not be complete. + auto was_maybe_number = [&]() { + if (!str.empty() && std::isspace(str.back())) { + return false; + } + return std::isdigit(last_non_sp_char) || + last_non_sp_char == '.' 
|| + last_non_sp_char == 'e' || + last_non_sp_char == 'E' || + last_non_sp_char == '-'; + }; + + std::string closing; + for (size_t i = err_loc.stack.size(); i > 0; i--) { + auto & el = err_loc.stack[i - 1]; + if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) { + closing += "}"; + } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) { + closing += "]"; + } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) { + throw std::runtime_error("Unexpected stack element type"); + } + } + + const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$"; + + if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) { + // We're inside an object value + if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) { + // Was about to create an object value + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } else if (can_parse(str + ": 1" + closing)) { + str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing; + } else if (last_non_sp_char == '{' && can_parse(str + closing)) { + // Was about to create an object + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing; + } else if (can_parse(str + "\"" + closing)) { + // Was inside an object value string + str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing; + } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { + // Was inside an object value string after an escape + str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; + } else { + // find last : + auto last_pos = str.find_last_of(':'); + if (last_pos == std::string::npos) { + throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location"); + } + // Cutting back to opening : for object value + str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } + } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) { + if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) { + // Was about to create an array value + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } else if (can_parse(str + "\"" + closing)) { + // Was inside an array value string + str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing; + } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { + // Was inside an array value string after an escape + str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; + } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) { + // Had just finished a value + str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing; + } else { + auto last_pos = str.find_last_of("[,"); + if (last_pos == std::string::npos) { + throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location"); + } + // Cutting back to last [ or , for array value + str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } + } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) { + if ((last_non_sp_char == '{' && can_parse(str + closing)) || + (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) { + // Was about to create an object key+value + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing; + } else if 
(!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) { + // Was about to create an object key+value + str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing; + } else if (can_parse(str + "\": 1" + closing)) { + // Was inside an object key string + str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing; + } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) { + // Was inside an object key string after an escape + str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing; + } else { + auto last_pos = str.find_last_of(':'); + if (last_pos == std::string::npos) { + throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location"); + } + // fprintf(stderr, "Cutting back to last : for object key+value\n"); + str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } + } else { + throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location"); + } + // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str()); + out.json = json::parse(str); + it = temptative_end; + return true; + } + // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...) + // fprintf(stderr, "Closing: TODO\n"); + return false; + } + out.json = json::parse(it, end); + it = end; + return true; +} diff --git a/common/json-partial.h b/common/json-partial.h new file mode 100644 index 000000000..f63356dc4 --- /dev/null +++ b/common/json-partial.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +// Healing marker (empty if the JSON was fully parsed / wasn't healed). +struct common_healing_marker { + // Raw marker. + std::string marker; + + // Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format). + std::string json_dump_marker; +}; + +// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string) +struct common_json { + nlohmann::ordered_json json; + + common_healing_marker healing_marker; +}; + +// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty. +// +// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON. +// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker. +// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format). +// +// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again). +bool common_json_parse( + const std::string & input, + const std::string & healing_marker, + common_json & out); + +// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds. 
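A minimal sketch of the healing behaviour documented above; it is not part of the patch, and the marker string and truncated inputs are invented for illustration.

```cpp
#include "json-partial.h"

#include <string>

int main() {
    common_json out;

    // A tool-call argument object cut off in the middle of a string value.
    std::string partial = R"({"location": "Par)";
    if (common_json_parse(partial, /* healing_marker= */ "$MARKER$", out)) {
        // out.json is now complete JSON, roughly {"location":"Par$MARKER$"} (the exact cut
        // point depends on where the SAX parser reported the error).
        std::string dump = out.json.dump();

        // Cutting the dump at json_dump_marker recovers the original prefix (modulo spaces),
        // which is how partial tool calls are re-cut when crafting OAI-style streamed chunks.
        std::string recut = dump.substr(0, dump.find(out.healing_marker.json_dump_marker));
    }

    // Note: a trailing, possibly incomplete number is dropped rather than healed, e.g.
    //   {"x": [1, 2   ->   {"x":[1,"$MARKER$"]}   (the "2" might have been "20", "2.5", ...).
    return 0;
}
```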
+bool common_json_parse( + std::string::const_iterator & it, + const std::string::const_iterator & end, + const std::string & healing_marker, + common_json & out); diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 5b3059c2f..d38a74f95 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -1,8 +1,9 @@ #include "json-schema-to-grammar.h" #include "common.h" +#include + #include -#include #include #include #include diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h index 4613f5d9f..362991b54 100644 --- a/common/json-schema-to-grammar.h +++ b/common/json-schema-to-grammar.h @@ -1,9 +1,9 @@ #pragma once -#include "ggml.h" -// Change JSON_ASSERT from assert() to GGML_ASSERT: -#define JSON_ASSERT GGML_ASSERT -#include "json.hpp" +#include + +#include +#include std::string json_schema_to_grammar(const nlohmann::ordered_json & schema, bool force_gbnf = false); diff --git a/common/sampling.cpp b/common/sampling.cpp index 28705e24c..9c04d35fd 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -161,7 +161,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE } else { - std::vector patterns_at_start; + std::vector trigger_patterns; std::vector patterns_anywhere; std::vector trigger_tokens; for (const auto & trigger : params.grammar_triggers) { @@ -173,10 +173,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: - case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START: { - const auto & pattern = trigger.value; - (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? 
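For the sampler change in common/sampling.cpp around this point, a hypothetical illustration (not from the patch) of the two trigger-pattern kinds after the rename: `COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN` values may match anywhere in the generated text, so they get wrapped into a single full-match expression, while `COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL` values are passed through as full-match patterns. The helper below is a standalone re-statement of that wrapping.

```cpp
#include <regex>
#include <string>
#include <vector>

// Sketch of the pattern wrapping: "anywhere" patterns are OR-ed together and embedded in a
// full-match regex, full patterns are used as-is.
static std::vector<std::string> build_trigger_patterns(const std::vector<std::string> & anywhere,
                                                       const std::vector<std::string> & full) {
    std::vector<std::string> out = full;
    if (!anywhere.empty()) {
        std::string joined;
        for (size_t i = 0; i < anywhere.size(); ++i) {
            joined += (i ? "|" : "") + anywhere[i];
        }
        out.push_back("^[\\s\\S]*?(" + joined + ")[\\s\\S]*");
    }
    return out;
}

int main() {
    const auto patterns = build_trigger_patterns(/* anywhere */ {"<tool_call>"}, /* full */ {});
    const std::regex re(patterns.front());
    // Matches: the trigger token may appear anywhere in the generated text.
    const bool triggered = std::regex_match(std::string("some text <tool_call> {"), re);
    return triggered ? 0 : 1;
}
```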
patterns_at_start : patterns_anywhere).push_back(pattern); + patterns_anywhere.push_back(trigger.value); + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: + { + trigger_patterns.push_back(trigger.value); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: @@ -190,10 +193,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } } - std::vector trigger_patterns; - if (!patterns_at_start.empty()) { - trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*"); - } if (!patterns_anywhere.empty()) { trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*"); } diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 753c88e7c..4c566af5f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -45,7 +45,7 @@ class SentencePieceTokenTypes(IntEnum): class ModelType(IntEnum): TEXT = 1 - VISION = 2 + MMPROJ = 2 AnyModel = TypeVar("AnyModel", bound="type[ModelBase]") @@ -54,7 +54,7 @@ AnyModel = TypeVar("AnyModel", bound="type[ModelBase]") class ModelBase: _model_classes: dict[ModelType, dict[str, type[ModelBase]]] = { ModelType.TEXT: {}, - ModelType.VISION: {}, + ModelType.MMPROJ: {}, } dir_model: Path @@ -88,7 +88,7 @@ class ModelBase: small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None): if type(self) is ModelBase or \ type(self) is TextModel or \ - type(self) is VisionModel: + type(self) is MmprojModel: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") self.dir_model = dir_model @@ -309,6 +309,7 @@ class ModelBase: gguf.MODEL_TENSOR.POSNET_NORM1, gguf.MODEL_TENSOR.POSNET_NORM2, gguf.MODEL_TENSOR.V_ENC_EMBD_POS, + gguf.MODEL_TENSOR.A_ENC_EMBD_POS, ) ) or not new_name.endswith(".weight") @@ -422,23 +423,26 @@ class ModelBase: try: # for security reason, we don't allow loading remote code by default # if a model need remote code, we will fallback to config.json - return AutoConfig.from_pretrained(dir_model, trust_remote_code=False).to_dict() + config = AutoConfig.from_pretrained(dir_model, trust_remote_code=False).to_dict() except Exception as e: logger.warning(f"Failed to load model config from {dir_model}: {e}") logger.warning("Trying to load config.json instead") with open(dir_model / "config.json", "r", encoding="utf-8") as f: config = json.load(f) - if "llm_config" in config: - # rename for InternVL - config["text_config"] = config["llm_config"] - return config + if "llm_config" in config: + # rename for InternVL + config["text_config"] = config["llm_config"] + if "thinker_config" in config: + # rename for Qwen2.5-Omni + config["text_config"] = config["thinker_config"]["text_config"] + return config @classmethod def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: assert names def func(modelcls: AnyModel) -> AnyModel: - model_type = ModelType.VISION if modelcls.model_arch == gguf.MODEL_ARCH.CLIP_VISION else ModelType.TEXT + model_type = ModelType.MMPROJ if modelcls.model_arch == gguf.MODEL_ARCH.MMPROJ else ModelType.TEXT for name in names: cls._model_classes[model_type][name] = modelcls return modelcls @@ -519,15 +523,15 @@ class TextModel(ModelBase): self.gguf_writer.add_context_length(n_ctx) logger.info(f"gguf: context length = {n_ctx}") - if (n_embd := self.find_hparam(["hidden_size", "n_embd"], optional=True)) is not None: + if (n_embd := self.find_hparam(["hidden_size", "n_embd", "dim"], optional=True)) is not None: self.gguf_writer.add_embedding_length(n_embd) 
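A small illustration (assumed config shape, not from the patch) of why the fallback key lists above were extended: configs that use Mistral-style names are now picked up by `find_hparam`, which returns the first key present.

```python
# Hypothetical config.json contents using alternate key names:
hparams = {"dim": 4096, "hidden_dim": 14336, "n_heads": 32}
# find_hparam(["hidden_size", "n_embd", "dim"])               -> 4096
# find_hparam(["intermediate_size", "n_inner", "hidden_dim"]) -> 14336
# find_hparam(["num_attention_heads", "n_head", "n_heads"])   -> 32
```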
logger.info(f"gguf: embedding length = {n_embd}") - if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None: + if (n_ff := self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"], optional=True)) is not None: self.gguf_writer.add_feed_forward_length(n_ff) logger.info(f"gguf: feed forward length = {n_ff}") - if (n_head := self.find_hparam(["num_attention_heads", "n_head"], optional=True)) is not None: + if (n_head := self.find_hparam(["num_attention_heads", "n_head", "n_heads"], optional=True)) is not None: self.gguf_writer.add_head_count(n_head) logger.info(f"gguf: head count = {n_head}") @@ -670,12 +674,12 @@ class TextModel(ModelBase): if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed": # ref: https://huggingface.co/tiiuae/falcon-7b res = "falcon" - if chkhsh == "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e": - # ref: https://huggingface.co/tiiuae/Falcon3-7B-Base - res = "falcon3" if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": # ref: https://huggingface.co/BAAI/bge-small-en-v1.5 res = "bert-bge" + if chkhsh == "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e": + # ref: https://huggingface.co/tiiuae/Falcon3-7B-Base + res = "falcon3" if chkhsh == "8e62295832751ca1e8f92f2226f403dea30dc5165e448b5bfa05af5340c64ec7": # ref: https://huggingface.co/BAAI/bge-large-zh-v1.5 res = "bert-bge-large" @@ -727,9 +731,6 @@ class TextModel(ModelBase): if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a": # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code res = "jina-v2-code" - if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b" or chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516": - # ref: https://huggingface.co/THUDM/glm-4-9b-chat - res = "chatglm-bpe" if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee": # ref: https://huggingface.co/LumiOpen/Viking-7B res = "viking" @@ -760,9 +761,6 @@ class TextModel(ModelBase): if chkhsh == "60824e3c0d9401f89943cbb2fff727f0e2d4c545ba4df2d6e4f09a6db0f5b450": # ref: https://huggingface.co/facebook/chameleon-7b res = "chameleon" - if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": - # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 - res = "minerva-7b" if chkhsh == "8b5a93ed704057481f240da0be7e7dca721d7f8f4755263b6807227a2cbeae65": # ref: https://huggingface.co/sentence-transformers/stsb-roberta-base res = "roberta-bpe" @@ -793,15 +791,24 @@ class TextModel(ModelBase): if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406": # ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct res = "llama4" - if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": - # ref: https://huggingface.co/THUDM/glm-4-9b-hf - res = "glm4" if chkhsh == "0e9433cbbb161f89e264eb32e8e64bfe69e834973ffca5d41d3948a604a3e2a3": # ref: https://huggingface.co/mistral-community/pixtral-12b res = "pixtral" if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec": # ref: https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base res = "seed-coder" + if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b": + # ref: https://huggingface.co/THUDM/glm-4-9b-chat + res = "chatglm-bpe" + if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516": + # ref: 
https://huggingface.co/THUDM/glm-4-9b-chat + res = "chatglm-bpe" + if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": + # ref: https://huggingface.co/THUDM/glm-4-9b-hf + res = "glm4" + if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": + # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 + res = "minerva-7b" if res is None: logger.warning("\n") @@ -1040,6 +1047,10 @@ class TextModel(ModelBase): special_vocab.chat_template = "rwkv-world" # hack: Add '\n\n' as the EOT token to make it chat normally special_vocab._set_special_token("eot", 261) + # hack: Override these as they have already been set (incorrectly) + special_vocab.special_token_ids["bos"] = 0 + special_vocab.special_token_ids["eos"] = 0 + special_vocab.add_to_gguf(self.gguf_writer) def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int): @@ -1114,60 +1125,116 @@ class TextModel(ModelBase): self.gguf_writer.add_pooling_type(pooling_type) -class VisionModel(ModelBase): - model_type = ModelType.VISION - model_arch = gguf.MODEL_ARCH.CLIP_VISION +class MmprojModel(ModelBase): + model_type = ModelType.MMPROJ + model_arch = gguf.MODEL_ARCH.MMPROJ preprocessor_config: dict[str, Any] global_config: dict[str, Any] + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"] + + has_vision_encoder: bool = True # by default + has_audio_encoder: bool = False + + # for models having multiple encoders, we need to separate their hparams + hparams_vision: dict[str, Any] | None = None + hparams_audio: dict[str, Any] | None = None + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION: - raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION") + if self.model_arch != gguf.MODEL_ARCH.MMPROJ: + raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ") # get n_embd of the text model if "text_config" not in self.hparams: self.hparams["text_config"] = {} + if "audio_config" not in self.hparams: + self.hparams["audio_config"] = {} text_config = {**self.hparams, **self.hparams["text_config"]} self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0)) assert self.n_embd_text > 0, "n_embd not found in hparams" - if "vision_config" not in self.hparams: - raise ValueError("vision_config not found in hparams") # move vision config to the top level, while preserving the original hparams in global_config - self.global_config = self.hparams - self.hparams = self.hparams["vision_config"] + import copy + self.global_config = copy.deepcopy(self.hparams) + self.hparams_vision = self.get_vision_config() + self.hparams_audio = self.get_audio_config() - self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]) - self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count) + if self.hparams_vision is None and self.hparams_audio is None: + raise ValueError("vision_config / audio_config not found in hparams") + + # for compat with vision-only models + self.hparams = self.hparams_vision or self.hparams_audio or self.hparams + + # TODO @ngxson : this is a hack to support both vision and audio encoders + have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder + self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True) + self.tensor_map = 
gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) # load preprocessor config with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: self.preprocessor_config = json.load(f) + def get_vision_config(self) -> dict[str, Any] | None: + return self.global_config.get("vision_config") + + def get_audio_config(self) -> dict[str, Any] | None: + return self.global_config.get("audio_config") + def set_type(self): - self.gguf_writer.add_type(gguf.GGUFType.CLIP_VISION) + self.gguf_writer.add_type(gguf.GGUFType.MMPROJ) def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) - self.gguf_writer.add_vision_projection_dim(self.n_embd_text) - self.gguf_writer.add_vision_has_vision_encoder(True) - # vision config - self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"])) - self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"])) - self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"])) - self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"])) - self.gguf_writer.add_vision_block_count(self.block_count) - self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"])) + if self.has_vision_encoder: + self.gguf_writer.add_clip_has_vision_encoder(True) + self.gguf_writer.add_vision_projection_dim(self.n_embd_text) - # preprocessor config - self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"]) - self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"]) + # vision config + self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"])) + self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"])) + self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"])) + self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"])) + self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys)) + self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"])) + + # preprocessor config + self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"]) + self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"]) + + if self.has_audio_encoder: + self.gguf_writer.add_clip_has_audio_encoder(True) + self.gguf_writer.add_audio_projection_dim(self.n_embd_text) + + # audio config + self.gguf_writer.add_audio_embedding_length(self.find_aparam(["hidden_size"])) + self.gguf_writer.add_audio_feed_forward_length(self.find_aparam(["intermediate_size"])) + self.gguf_writer.add_audio_block_count(self.find_aparam(self.n_block_keys)) + self.gguf_writer.add_audio_head_count(self.find_aparam(["num_attention_heads"])) + + if not self.has_vision_encoder and not self.has_audio_encoder: + raise ValueError("MmprojModel must have either vision or audio encoder") def write_vocab(self): - raise ValueError("VisionModel does not support vocab writing") + raise ValueError("MmprojModel does not support vocab writing") + + def find_vparam(self, keys: Iterable[str], optional: bool = False) -> Any: + assert self.hparams_vision is not None + return self._find_param(self.hparams_vision, keys, optional) + + def find_aparam(self, keys: Iterable[str], optional: bool = False) -> Any: + assert self.hparams_audio is not None + return self._find_param(self.hparams_audio, keys, optional) + + def _find_param(self, obj: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any: + key = next((k for k in keys if k in obj), 
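A hedged sketch (not from the patch) of how a multimodal subclass is expected to use the split hparams accessors above; the class name and config keys are invented for illustration.

```python
# Hypothetical subclass layout, assuming the MmprojModel helpers defined above.
class ExampleOmniModel(MmprojModel):
    has_vision_encoder = True
    has_audio_encoder = True

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        # find_vparam()/find_aparam() look up keys in vision_config / audio_config only,
        # so the two encoders can reuse the same key names without clashing.
        n_vision_embd = self.find_vparam(["hidden_size"])
        n_audio_embd = self.find_aparam(["hidden_size", "d_model"])
        # Optional keys return None instead of raising.
        window = self.find_vparam(["window_size"], optional=True)
```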
None) + if key is not None: + return obj[key] + if optional: + return None + raise KeyError(f"could not find any of: {keys}") @ModelBase.register("GPTNeoXForCausalLM") @@ -1781,7 +1848,8 @@ class StableLMModel(TextModel): "MistralForCausalLM", "MixtralForCausalLM", "VLlama3ForCausalLM", - "LlavaForConditionalGeneration") + "LlavaForConditionalGeneration", + "LlamaModel") class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA undo_permute = True @@ -1861,6 +1929,8 @@ class LlamaModel(TextModel): if is_vision_tensor: return [] # skip vision tensors + elif self.hf_arch == "LlamaModel": + name = "model." + name elif name.startswith("model.text_model"): name = name.replace("text_model.", "") # for SmolVLM elif name.startswith("language_model."): @@ -1951,7 +2021,7 @@ class LlamaModel(TextModel): "LlavaForConditionalGeneration", # pixtral "Mistral3ForConditionalGeneration", # mistral small 3.1 ) -class LlavaVisionModel(VisionModel): +class LlavaVisionModel(MmprojModel): img_break_tok_id = -1 def __init__(self, *args, **kwargs): @@ -1977,7 +2047,7 @@ class LlavaVisionModel(VisionModel): super().set_gguf_parameters() hparams = self.hparams if hparams["model_type"] == "pixtral": - self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL) + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL) self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"]) # hidden_act @@ -2016,7 +2086,7 @@ class LlavaVisionModel(VisionModel): @ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration") -class SmolVLMModel(VisionModel): +class SmolVLMModel(MmprojModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.hparams["model_type"] == "smolvlm_vision": @@ -2028,7 +2098,7 @@ class SmolVLMModel(VisionModel): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.IDEFICS3) + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.IDEFICS3) self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5)) self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2)) self.gguf_writer.add_vision_use_gelu(True) @@ -2094,10 +2164,10 @@ class Llama4Model(LlamaModel): @ModelBase.register("Llama4ForConditionalGeneration") -class Llama4VisionModel(VisionModel): +class Llama4VisionModel(MmprojModel): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4) + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LLAMA4) self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"]) self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"])) assert self.hparams["hidden_act"] == "gelu" @@ -2109,6 +2179,9 @@ class Llama4VisionModel(VisionModel): # process vision tensors if "positional_embedding_vlm" in name and ".weight" not in name: name += ".weight" + if "multi_modal_projector.linear_1" in name: + # despite the name with number postfix, this is a single fully connected layer + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC], data_torch)] return [(self.map_tensor_name(name), data_torch)] return [] @@ -2615,7 +2688,7 @@ class QwenModel(TextModel): self.gguf_writer.add_file_type(self.ftype) -@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM") +@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", 
"Qwen2AudioForConditionalGeneration") class Qwen2Model(TextModel): model_arch = gguf.MODEL_ARCH.QWEN2 @@ -2639,13 +2712,19 @@ class Qwen2Model(TextModel): name = f"model.{name}" # map to Qwen2ForCausalLM tensors if "language_model." in name: name = name.replace("language_model.", "") # for InternVL - if name.startswith("mlp") or name.startswith("vision_model"): - # skip visual tensors + if name.startswith("mlp") or name.startswith("multi_modal_projector") \ + or name.startswith("vision_model") or name.startswith("audio_tower"): + # skip vision and audio tensors return [] yield from super().modify_tensors(data_torch, name, bid) -@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") +@ModelBase.register( + "Qwen2VLModel", + "Qwen2VLForConditionalGeneration", + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5OmniModel", +) class Qwen2VLModel(TextModel): model_arch = gguf.MODEL_ARCH.QWEN2VL @@ -2663,31 +2742,40 @@ class Qwen2VLModel(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused - if name.startswith("visual."): - # skip visual tensors + if name.startswith("thinker."): + name = name.replace("thinker.", "") + if name.startswith("visual") or name.startswith("audio") or \ + name.startswith("talker") or name.startswith("token2wav"): + # skip multimodal tensors return [] return [(self.map_tensor_name(name), data_torch)] @ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") -class Qwen2VLVisionModel(VisionModel): +class Qwen2VLVisionModel(MmprojModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.hparams["image_size"] = self.hparams.get("image_size", 560) + assert self.hparams_vision is not None + self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560) # rename config.json values - self.hparams["num_attention_heads"] = self.hparams.get("num_heads") - self.hparams["num_hidden_layers"] = self.hparams.get("depth") - if "embed_dim" in self.hparams: # qwen2vl - self.hparams["intermediate_size"] = self.hparams.get("hidden_size") - self.hparams["hidden_size"] = self.hparams.get("embed_dim") + self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads") + self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth") + if "embed_dim" in self.hparams_vision: # qwen2vl + self.hparams_vision["intermediate_size"] = self.hparams_vision.get("hidden_size") + self.hparams_vision["hidden_size"] = self.hparams_vision.get("embed_dim") def set_gguf_parameters(self): super().set_gguf_parameters() - hparams = self.hparams - if self.global_config['model_type'] == 'qwen2_vl': - self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN2VL) - elif self.global_config['model_type'] == 'qwen2_5_vl': - self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN25VL) + assert self.hparams_vision is not None + hparams = self.hparams_vision + model_type = self.global_config['model_type'] + if model_type == 'qwen2_vl': + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL) + elif model_type == 'qwen2_5_vl' or model_type == 'qwen2_5_omni': + if model_type == 'qwen2_5_omni': + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O) + else: + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL) self.gguf_writer.add_vision_use_silu(True) # find n_wa_pattern (window 
attention pattern) fullatt_block_indexes = hparams.get("fullatt_block_indexes") @@ -2745,12 +2833,72 @@ class Qwen2VLVisionModel(VisionModel): return [] # skip other tensors +@ModelBase.register("Qwen2_5OmniModel") +class Qwen25OmniModel(Qwen2VLVisionModel): + has_vision_encoder = True + has_audio_encoder = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_audio is not None + self.hparams_audio["hidden_size"] = self.hparams_audio["d_model"] + self.hparams_audio["intermediate_size"] = self.hparams_audio["encoder_ffn_dim"] + self.hparams_audio["num_attention_heads"] = self.hparams_audio["encoder_attention_heads"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + assert self.hparams_audio is not None + self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"]) + self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5)) + + def get_vision_config(self) -> dict[str, Any] | None: + return self.global_config["thinker_config"].get("vision_config") + + def get_audio_config(self) -> dict[str, Any] | None: + return self.global_config["thinker_config"].get("audio_config") + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # SinusoidsPositionEmbedding + assert self.hparams_audio is not None + max_timescale = 10000 + length = 1500 + channels = self.hparams_audio["hidden_size"] + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + pos_embd = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1).to(dtype=torch.float32) + yield ("audio_tower.embed_positions.weight", pos_embd) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, new_name, n_dims # unused + if ".conv" in name and ".weight" in name: + return gguf.GGMLQuantizationType.F16 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("thinker."): + name = name.replace("thinker.", "") + + if name.startswith("audio_tower"): + # process audio tensors + if "conv1.bias" in name or "conv2.bias" in name: + # transpose conv1 and conv2 bias + data_torch = data_torch.unsqueeze(-1) + if "audio_bos_eos_token" in name: + # this tensor is left unused in transformers code + # https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py#L1809 + return [] + return [(self.map_tensor_name(name), data_torch)] + + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("InternVisionModel") -class InternVisionModel(VisionModel): +class InternVisionModel(MmprojModel): def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams - self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.INTERNVL) + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL) self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"]) # hidden_act if hparams["hidden_act"] == "silu": @@ -3541,7 +3689,7 @@ class InternLM3Model(TextModel): return [(self.map_tensor_name(name), data_torch)] -@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel") +@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel", 
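For reference, a NumPy re-derivation (illustrative only) of the Whisper-style sinusoidal position table that `generate_extra_tensors()` above emits as `audio_tower.embed_positions.weight`. The 1280-channel default is an assumption based on Whisper-large; the converter reads it from the audio config's `hidden_size`/`d_model`.

```python
import numpy as np

def sinusoids(length: int = 1500, channels: int = 1280, max_timescale: float = 10000.0) -> np.ndarray:
    # Geometrically spaced frequencies, half sine / half cosine per position.
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = np.exp(-log_timescale_increment * np.arange(channels // 2))
    scaled_time = np.arange(length)[:, None] * inv_timescales[None, :]
    return np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1).astype(np.float32)

# sinusoids().shape == (1500, channels); row t encodes position t at all frequencies.
```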
"BertForSequenceClassification") class BertModel(TextModel): model_arch = gguf.MODEL_ARCH.BERT @@ -3549,11 +3697,21 @@ class BertModel(TextModel): super().__init__(*args, **kwargs) self.vocab_size = None + if cls_out_labels := self.hparams.get("id2label"): + if len(cls_out_labels) == 2 and cls_out_labels[0] == "LABEL_0": + # Remove dummy labels added by AutoConfig + cls_out_labels = None + self.cls_out_labels = cls_out_labels + def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_causal_attention(False) self._try_set_pooling_type() + if self.cls_out_labels: + key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch]) + self.gguf_writer.add_array(key_name, [v for k, v in sorted(self.cls_out_labels.items())]) + def set_vocab(self): tokens, toktypes, tokpre = self.get_vocab_base() self.vocab_size = len(tokens) @@ -3604,6 +3762,14 @@ class BertModel(TextModel): if name.startswith("cls.seq_relationship"): return [] + if self.cls_out_labels: + # For BertForSequenceClassification (direct projection layer) + if name == "classifier.weight": + name = "classifier.out_proj.weight" + + if name == "classifier.bias": + name = "classifier.out_proj.bias" + return [(self.map_tensor_name(name), data_torch)] def _xlmroberta_tokenizer_init(self) -> None: @@ -3648,7 +3814,7 @@ class BertModel(TextModel): remove_whitespaces = tokenizer.clean_up_tokenization_spaces precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"]) - vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size) + vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size) else: sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) @@ -3661,7 +3827,7 @@ class BertModel(TextModel): tokenizer = SentencePieceProcessor() tokenizer.LoadFromFile(str(tokenizer_path)) - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size()) tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size @@ -3691,33 +3857,26 @@ class BertModel(TextModel): unk_token = tokenizer_config_json.get("unk_token") unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3)) - for token_id in range(vocab_size): + for token_id in range(tokenizer.vocab_size): piece = tokenizer._convert_id_to_token(token_id) - text = piece.encode("utf-8") - score = tokenizer_json["model"]["vocab"][token_id][1] + if (piece := tokenizer._convert_id_to_token(token_id)) is not None: + text = piece.encode("utf-8") + score = tokenizer_json["model"]["vocab"][token_id][1] - toktype = SentencePieceTokenTypes.NORMAL - if token_id == unk_token_id: - toktype = SentencePieceTokenTypes.UNKNOWN - elif token_id in tokenizer.all_special_ids: - toktype = SentencePieceTokenTypes.CONTROL - elif token_id in added_vocab.values(): - toktype = SentencePieceTokenTypes.USER_DEFINED - # No reliable way to detect this, but jina-embeddings-v3 doesn't have any - # elif tokenizer.IsByte(token_id): - # toktype = SentencePieceTokenTypes.BYTE + toktype = SentencePieceTokenTypes.NORMAL + if token_id == unk_token_id: + toktype = SentencePieceTokenTypes.UNKNOWN + elif token_id in tokenizer.all_special_ids: + toktype = SentencePieceTokenTypes.CONTROL + elif token_id in added_vocab.values(): + toktype = SentencePieceTokenTypes.USER_DEFINED + # 
No reliable way to detect this, but jina doesn't have any + # elif tokenizer.IsByte(token_id): + # toktype = SentencePieceTokenTypes.BYTE - tokens[token_id] = text - scores[token_id] = score - toktypes[token_id] = toktype - - if vocab_size > len(tokens): - pad_count = vocab_size - len(tokens) - logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") - for i in range(1, pad_count + 1): - tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) - scores.append(-1000.0) - toktypes.append(SentencePieceTokenTypes.UNUSED) + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype if isinstance(tokenizer, SentencePieceProcessor): # realign tokens (see HF tokenizer code) @@ -3730,6 +3889,12 @@ class BertModel(TextModel): SentencePieceTokenTypes.UNKNOWN, ] + toktypes[3:-1] + if self.model_arch == gguf.MODEL_ARCH.NOMIC_BERT_MOE: + # Add mask token missing from sentencepiece.bpe.model + tokens[250001] = b'' + scores[250001] = 0.0 + toktypes[250001] = SentencePieceTokenTypes.CONTROL + self.gguf_writer.add_tokenizer_model("t5") self.gguf_writer.add_tokenizer_pre("default") self.gguf_writer.add_token_list(tokens) @@ -3748,7 +3913,27 @@ class BertModel(TextModel): self.gguf_writer.add_add_eos_token(True) -@ModelBase.register("RobertaModel") +@ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification") +class DistilBertModel(BertModel): + model_arch = gguf.MODEL_ARCH.BERT + + def set_gguf_parameters(self): + self.gguf_writer.add_layer_norm_eps(1e-12) + logger.info("gguf: layer norm epsilon = 1e-12") + super().set_gguf_parameters() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("distilbert."): + name = name[11:] + + # These layers act as MLM head, so we don't need them + if name.startswith("vocab_"): + return [] + + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("RobertaModel", "RobertaForSequenceClassification") class RobertaModel(BertModel): model_arch = gguf.MODEL_ARCH.BERT @@ -4086,11 +4271,11 @@ class Gemma3Model(TextModel): @ModelBase.register("Gemma3ForConditionalGeneration") -class Gemma3VisionModel(VisionModel): +class Gemma3VisionModel(MmprojModel): def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams - self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.GEMMA3) + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3) # default values below are taken from HF tranformers code self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6)) self.gguf_writer.add_vision_use_gelu(True) @@ -6037,6 +6222,65 @@ class ChameleonModel(TextModel): return data_torch +@ModelBase.register("UltravoxModel") +class UltravoxModel(TextModel): + model_arch = gguf.MODEL_ARCH.LLAMA # dummy + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. 
If you want to get the audio encoder, please use --mmproj argument") + + +@ModelBase.register("Qwen2AudioForConditionalGeneration") +class WhisperEncoderModel(MmprojModel): + has_vision_encoder = False # no vision encoder + has_audio_encoder = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hparams["hidden_size"] = self.hparams["d_model"] + self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"] + self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2A) + self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"]) + self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5)) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, new_name, n_dims # unused + if ".conv" in name and ".weight" in name: + return gguf.GGMLQuantizationType.F16 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.startswith("language_model."): + # skip language model tensors + return [] + + # prevent clash naming with vision tensors + if name.startswith("multi_modal_projector"): + name = "audio." + name + + if "conv1.bias" in name or "conv2.bias" in name: + # transpose conv1 and conv2 bias + data_torch = data_torch.unsqueeze(-1) + + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("UltravoxModel") +class UltravoxWhisperEncoderModel(WhisperEncoderModel): + has_vision_encoder = False # no vision encoder + has_audio_encoder = True + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"]) + ###### CONVERSION LOGIC ###### @@ -6212,13 +6456,15 @@ def split_str_to_n_bytes(split_str: str) -> int: def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str: + # TODO @ngxson : this won't work correctly if the model has both audio & vision encoders + # maybe we should fallback to text model's arch in that case, since not many models have both text_config = hparams.get("text_config", {}) vision_config = hparams.get("vision_config", {}) arch = hparams["architectures"][0] # if "architectures" is found in the sub-config, use that instead if model_type == ModelType.TEXT and text_config.get("architectures") is not None: arch = text_config["architectures"][0] - elif model_type == ModelType.VISION and vision_config.get("architectures") is not None: + elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None: arch = vision_config["architectures"][0] return arch @@ -6281,7 +6527,7 @@ def main() -> None: with torch.inference_mode(): output_type = ftype_map[args.outtype] - model_type = ModelType.VISION if args.mmproj else ModelType.TEXT + model_type = ModelType.MMPROJ if args.mmproj else ModelType.TEXT hparams = ModelBase.load_hparams(dir_model) model_architecture = get_model_architecture(hparams, model_type) logger.info(f"Model architecture: {model_architecture}") diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 5993a4c98..2f733f097 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -1,28 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# This script downloads the tokenizer models of the specified models from Huggingface and -# 
generates the get_vocab_base_pre() function for convert_hf_to_gguf.py -# -# This is necessary in order to analyze the type of pre-tokenizer used by the model and -# provide the necessary information to llama.cpp via the GGUF header in order to implement -# the same pre-tokenizer. -# -# ref: https://github.com/ggml-org/llama.cpp/pull/6920 -# -# Instructions: -# -# - Add a new model to the "models" list -# - Run the script with your huggingface token: -# -# python3 convert_hf_to_gguf_update.py -# -# - The convert_hf_to_gguf.py script will have had its get_vocab_base_pre() function updated -# - Update llama.cpp with the new pre-tokenizer if necessary -# -# TODO: generate tokenizer tests for llama.cpp -# - import logging import os import pathlib @@ -32,6 +10,7 @@ import requests import sys import json import shutil +import argparse from hashlib import sha256 from enum import IntEnum, auto @@ -41,6 +20,11 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger("convert_hf_to_gguf_update") sess = requests.Session() +convert_py_pth = pathlib.Path("convert_hf_to_gguf.py") +convert_py = convert_py_pth.read_text(encoding="utf-8") +hf_token_pth = pathlib.Path.home() / ".cache" / "huggingface" / "token" +hf_token = hf_token_pth.read_text(encoding="utf-8").strip() if hf_token_pth.exists() else None + class TOKENIZER_TYPE(IntEnum): SPM = auto() @@ -49,20 +33,49 @@ class TOKENIZER_TYPE(IntEnum): UGM = auto() +DOC_STRING = """ +This script downloads the tokenizer models of the specified models from Huggingface and +generates the get_vocab_base_pre() function for convert_hf_to_gguf.py + +/!\\ It is intended to be used by contributors and is not meant to be run by end users + +This is necessary in order to analyze the type of pre-tokenizer used by the model and +provide the necessary information to llama.cpp via the GGUF header in order to implement +the same pre-tokenizer. + +ref: https://github.com/ggml-org/llama.cpp/pull/6920 + +Instructions: + +- Add a new model to the "models" list +- Run the script with your huggingface token + By default, token will be read from ~/.cache/huggingface/token +- The convert_hf_to_gguf.py script will have had its get_vocab_base_pre() function updated +- Update llama.cpp with the new pre-tokenizer if necessary +""" +# TODO: generate tokenizer tests for llama.cpp + +parser = argparse.ArgumentParser(description=DOC_STRING, formatter_class=argparse.RawTextHelpFormatter) +parser.add_argument( + "--full", action="store_true", + help="download full list of models - make sure you have access to all of them", +) +parser.add_argument( + "hf_token", + help="optional HF token", + nargs="?", +) +args = parser.parse_args() +hf_token = args.hf_token if args.hf_token is not None else hf_token + +if hf_token is None: + logger.error("HF token is required. Please provide it as an argument or set it in ~/.cache/huggingface/token") + sys.exit(1) + # TODO: this string has to exercise as much pre-tokenizer functionality as possible # will be updated with time - contributions welcome CHK_TXT = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````\"\"\"\"......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? 
We\'Ve a\'lL' -if len(sys.argv) == 2: - token = sys.argv[1] - if not token.startswith("hf_"): - logger.info("Huggingface token seems invalid") - logger.info("Usage: python convert_hf_to_gguf_update.py ") - sys.exit(1) -else: - logger.info("Usage: python convert_hf_to_gguf_update.py ") - sys.exit(1) - # TODO: add models here, base models preferred models = [ {"name": "llama-spm", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", }, @@ -103,7 +116,6 @@ models = [ {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", }, {"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", }, {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", }, - {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", }, {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"}, {"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"}, {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"}, @@ -114,11 +126,19 @@ models = [ {"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", }, {"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", }, {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", }, - {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", }, {"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", }, {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", }, ] +# some models are known to be broken upstream, so we will skip them as exceptions +pre_computed_hashes = [ + # chatglm-bpe has 2 hashes, why? 
+ {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"}, + {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, + {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, +] + def download_file_with_auth(url, token, save_path): headers = {"Authorization": f"Bearer {token}"} @@ -169,9 +189,29 @@ def download_model(model): if os.path.isfile(save_path): logger.info(f"{name}: File {save_path} already exists - skipping") continue - download_file_with_auth(f"{repo}/resolve/main/{file}", token, save_path) + download_file_with_auth(f"{repo}/resolve/main/{file}", hf_token, save_path) +# get list of existing models and chkhsh from the convert_hf_to_gguf.py file +# returns mapping res --> chkhsh +def get_existing_models(convert_py): + pattern = r'if chkhsh == "([a-f0-9]{64})":\s*\n\s*.*\s*res = "([^"]+)"' + matches = re.findall(pattern, convert_py) + output = {} + for chkhsh, res in matches: + output[res] = chkhsh + return output + + +existing_models = {} +all_models = models.copy() +if not args.full: + # Filter out models that already exist in convert_hf_to_gguf.py + existing_models = get_existing_models(convert_py) + all_models = models.copy() + models = [model for model in all_models if model["name"] not in existing_models] + +logging.info(f"Downloading {len(models)} models...") for model in models: try: download_model(model) @@ -182,9 +222,10 @@ for model in models: # generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function: src_ifs = "" -for model in models: +for model in [*all_models, *pre_computed_hashes]: name = model["name"] tokt = model["tokt"] + chkhsh = model.get("chkhsh") if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM: continue @@ -195,35 +236,44 @@ for model in models: continue # create the tokenizer - try: - if name == "t5": - tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) - else: - tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") - except OSError as e: - logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}") - continue # Skip to the next model if the tokenizer can't be loaded + if chkhsh is not None: + # if the model has a pre-computed hash, use it + logger.info(f"Using pre-computed hash for model {name}: {chkhsh}") + elif name in existing_models: + # if the model already exists in convert_hf_to_gguf.py, skip compute hash + chkhsh = existing_models[name] + else: + # otherwise, compute the hash of the tokenizer + try: + logger.info(f"Loading tokenizer from {f'models/tokenizers/{name}'}...") + if name == "t5": + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) + else: + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") + except OSError as e: + logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. 
Error: {e}") + continue # Skip to the next model if the tokenizer can't be loaded - chktok = tokenizer.encode(CHK_TXT) - chkhsh = sha256(str(chktok).encode()).hexdigest() + chktok = tokenizer.encode(CHK_TXT) + chkhsh = sha256(str(chktok).encode()).hexdigest() - logger.info(f"model: {name}") - logger.info(f"tokt: {tokt}") - logger.info(f"repo: {model['repo']}") - logger.info(f"chktok: {chktok}") - logger.info(f"chkhsh: {chkhsh}") + logger.info(f"model: {name}") + logger.info(f"tokt: {tokt}") + logger.info(f"repo: {model['repo']}") + logger.info(f"chktok: {chktok}") + logger.info(f"chkhsh: {chkhsh}") - # print the "pre_tokenizer" content from the tokenizer.json - with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f: - cfg = json.load(f) - normalizer = cfg["normalizer"] - logger.info("normalizer: " + json.dumps(normalizer, indent=4)) - pre_tokenizer = cfg["pre_tokenizer"] - logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4)) - if "ignore_merges" in cfg["model"]: - logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4)) + # print the "pre_tokenizer" content from the tokenizer.json + with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f: + cfg = json.load(f) + normalizer = cfg["normalizer"] + logger.info("normalizer: " + json.dumps(normalizer, indent=4)) + pre_tokenizer = cfg["pre_tokenizer"] + logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4)) + if "ignore_merges" in cfg["model"]: + logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4)) - logger.info("") + logger.info("") src_ifs += f" if chkhsh == \"{chkhsh}\":\n" src_ifs += f" # ref: {model['repo']}\n" @@ -271,8 +321,6 @@ src_func = f""" return res """ -convert_py_pth = pathlib.Path("convert_hf_to_gguf.py") -convert_py = convert_py_pth.read_text(encoding="utf-8") convert_py = re.sub( r"(# Marker: Start get_vocab_base_pre)(.+?)( +# Marker: End get_vocab_base_pre)", lambda m: m.group(1) + src_func + m.group(3), @@ -288,7 +336,7 @@ logger.info("+++ convert_hf_to_gguf.py was updated") tests = [ "ied 4 ½ months", - "Führer", + "Äpfel", "", " ", " ", @@ -367,6 +415,10 @@ for model in models: logger.error(f"Failed to load tokenizer for model {name}. Error: {e}") continue # Skip this model and continue with the next one in the loop + if not os.path.exists(f"models/ggml-vocab-{name}.gguf"): + logger.info(f"Skip vocab files for model {name}, no GGUF file found") + continue + with open(f"models/ggml-vocab-{name}.gguf.inp", "w", encoding="utf-8") as f: for text in tests: f.write(f"{text}") diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md old mode 100644 new mode 100755 index e172ec5c2..a5ba617ca --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -280,6 +280,15 @@ cmake --build build --config release ### **GitHub contribution**: Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay. +## Updates +### Basic Flash Attention Support +The basic FA kernel with aclnnops has been added in aclnn_ops.cpp. +Currently, the FA only supports the cases with FP16 KV tensors and NO logit softcap. +Since the aclnn interface for flash attention cannot support the logit softcap, we will only update the quantized version in the future. + +Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang@pku.edu.cn), Ruiyang Ma (ruiyang@stu.pku.edu.cn), and Guojie Luo (gluo@pku.edu.cn). 
+ +We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers from Huawei Technologies Co., Ltd for their help during the code development and pull request. ## TODO - Support more models and data types. diff --git a/docs/build.md b/docs/build.md index c9027c0b5..32717a793 100644 --- a/docs/build.md +++ b/docs/build.md @@ -63,6 +63,7 @@ cmake --build build --config Release cmake --preset x64-windows-llvm-release cmake --build build-x64-windows-llvm-release ``` +- Curl usage is enabled by default and can be turned off with `-DLLAMA_CURL=OFF`. Otherwise you need to install development libraries for libcurl. ## BLAS Build diff --git a/docs/function-calling.md b/docs/function-calling.md index c3873c3fa..fd3db9bd1 100644 --- a/docs/function-calling.md +++ b/docs/function-calling.md @@ -2,7 +2,6 @@ [chat.h](../common/chat.h) (https://github.com/ggml-org/llama.cpp/pull/9639) adds support for [OpenAI-style function calling](https://platform.openai.com/docs/guides/function-calling) and is used in: - `llama-server` when started w/ `--jinja` flag -- `llama-cli` (WIP: https://github.com/ggml-org/llama.cpp/pull/11556) ## Universal support w/ Native & Generic handlers @@ -325,36 +324,65 @@ To get the official template from original HuggingFace repos, you can use [scrip > [!TIP] > If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills) +> [!CAUTION] +> Beware of extreme KV quantizations (e.g. `-ctk q4_0`), they can substantially degrade the model's tool calling performance. + Test in CLI (or with any library / software that can use OpenAI-compatible API backends): ```bash curl http://localhost:8080/v1/chat/completions -d '{ -"model": "gpt-3.5-turbo", -"tools": [ - { - "type":"function", - "function":{ - "name":"python", - "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", - "parameters":{ - "type":"object", - "properties":{ - "code":{ - "type":"string", - "description":"The code to run in the ipython interpreter." + "model": "gpt-3.5-turbo", + "tools": [ + { + "type":"function", + "function":{ + "name":"python", + "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", + "parameters":{ + "type":"object", + "properties":{ + "code":{ + "type":"string", + "description":"The code to run in the ipython interpreter." + } + }, + "required":["code"] } - }, - "required":["code"] } - } - } -], -"messages": [ - { - "role": "user", - "content": "Print a hello world message with python." - } -] + } + ], + "messages": [ + { + "role": "user", + "content": "Print a hello world message with python." + } + ] +}' + + +curl http://localhost:8080/v1/chat/completions -d '{ + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."}, + {"role": "user", "content": "What is the weather in Istanbul?"} + ], + "tools": [{ + "type":"function", + "function":{ + "name":"get_current_weather", + "description":"Get the current weather in a given location", + "parameters":{ + "type":"object", + "properties":{ + "location":{ + "type":"string", + "description":"The city and country/state, e.g. 
`San Francisco, CA`, or `Paris, France`" + } + }, + "required":["location"] + } + } + }] }' ``` diff --git a/docs/multimodal.md b/docs/multimodal.md index 054778e91..e849c2a0b 100644 --- a/docs/multimodal.md +++ b/docs/multimodal.md @@ -4,7 +4,9 @@ llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools - [llama-mtmd-cli](../tools/mtmd/README.md) - [llama-server](../tools/server/README.md) via OpenAI-compatible `/chat/completions` API -To enable it, can use use one of the 2 methods below: +Currently, we support **image** and **audio** input. Audio is highly experimental and may have reduced quality. + +To enable it, you can use one of the 2 methods below: - Use `-hf` option with a supported model (see a list of pre-quantized model below) - To load a model using `-hf` while disabling multimodal, use `--no-mmproj` @@ -31,12 +33,14 @@ llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload ## Pre-quantized models -These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/ggml-org +These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/collections/ggml-org/multimodal-ggufs-68244e01ff1f39e5bebeeedc Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server` NOTE: some models may require large context window, for example: `-c 8192` +**Vision models**: + ```sh # Gemma 3 (tool_name) -hf ggml-org/gemma-3-4b-it-GGUF @@ -77,4 +81,29 @@ NOTE: some models may require large context window, for example: `-c 8192` # Llama 4 Scout (tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF + +# Moondream2 20250414 version +(tool_name) -hf ggml-org/moondream2-20250414-GGUF + +``` + +**Audio models**: + +```sh +# Ultravox 0.5 +(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF +(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF + +# Qwen2-Audio and SeaLLM-Audio +# note: no pre-quantized GGUF this model, as they have very poor result +# ref: https://github.com/ggml-org/llama.cpp/pull/13760 +``` + +**Mixed modalities**: + +```sh +# Qwen2.5 Omni +# Capabilities: audio input, vision input +(tool_name) -hf ggml-org/Qwen2.5-Omni-3B-GGUF +(tool_name) -hf ggml-org/Qwen2.5-Omni-7B-GGUF ``` diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 01ff6763f..71f700877 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -41,8 +41,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); - if (llama_encode(ctx, batch) < 0) { - LOG_ERR("%s : failed to encode\n", __func__); + if (llama_decode(ctx, batch) < 0) { + LOG_ERR("%s : failed to process\n", __func__); } for (int i = 0; i < batch.n_tokens; i++) { diff --git a/examples/parallel/README.md b/examples/parallel/README.md index ece3a6641..2468a30d2 100644 --- a/examples/parallel/README.md +++ b/examples/parallel/README.md @@ -4,7 +4,7 @@ Simplified simulation of serving incoming requests in parallel ## Example -Generate 128 client requests (`-ns 128`), simulating 8 concurrent clients (`-np 8`). The system prompt is shared (`-pps`), meaning that it is computed once at the start. 
The client requests consist of 10 junk questions (`-j 10`) followed by the actual question. +Generate 128 client requests (`-ns 128`), simulating 8 concurrent clients (`-np 8`). The system prompt is shared (`-pps`), meaning that it is computed once at the start. The client requests consist of up to 10 junk questions (`--junk 10`) followed by the actual question. ```bash llama-parallel -m model.gguf -np 8 -ns 128 --top-k 1 -pps --junk 10 -c 16384 diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index acb1301a2..cd85bea9a 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -158,7 +158,7 @@ int main(int argc, char ** argv) { common_params params; params.n_predict = 128; - params.n_junk = 0; + params.n_junk = 1; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) { return 1; @@ -182,7 +182,7 @@ int main(int argc, char ** argv) { const bool is_sp_shared = params.is_pp_shared; // extra text to insert in each client's prompt in order to make it larger - const int32_t n_junk = params.n_junk; + const int32_t n_junk = std::max(1, params.n_junk); // init llama.cpp llama_backend_init(); @@ -315,7 +315,10 @@ int main(int argc, char ** argv) { } else { client.prompt += k_system; } - for (int i = 0; i < n_junk; ++i) { + + const int n_junk_cur = rand() % n_junk; + + for (int i = 0; i < n_junk_cur; ++i) { const int r = rand() % k_questions.size(); client.prompt += "User:\n" + k_questions[r] + "\nAssistant:\n " + k_answers[r] + "\n"; } @@ -340,7 +343,7 @@ int main(int argc, char ** argv) { client.n_decoded = 0; client.i_batch = batch.n_tokens - 1; - LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); + LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur); g_seq_id += 1; @@ -359,7 +362,9 @@ int main(int argc, char ** argv) { // process in chunks of params.n_batch int32_t n_batch = params.n_batch; - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + int32_t i_next = 0; + + for (int32_t i = 0; i < batch.n_tokens; i = i_next) { // experiment: process in powers of 2 //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) { // n_batch /= 2; @@ -367,7 +372,7 @@ int main(int argc, char ** argv) { // continue; //} - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); llama_batch batch_view = { n_tokens, @@ -387,19 +392,24 @@ int main(int argc, char ** argv) { return 1; } - LOG_ERR("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2); + LOG_WRN("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2); n_cache_miss += 1; // retry with half the batch size to try to find a free slot in the KV cache n_batch /= 2; - i -= n_batch; continue; } LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens); + // move the head of the batch forward with the number of tokens we just processed + i_next = i + n_tokens; + + // on successful decode, restore the original batch size + n_batch = params.n_batch; + for (auto & client : clients) { if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) { continue; diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 347ea4a69..5ac881b45 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -133,9 +133,8 @@ int main(int argc, char ** argv) { const int ib = 
i/n_batch - 1; const int bd = n_batch_grp*(n_grp - 1); - llama_kv_self_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); - llama_kv_self_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); - llama_kv_self_update (ctx); + llama_kv_self_seq_add(ctx, 0, n_past - n_batch, n_past, ib*bd); + llama_kv_self_seq_div(ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; } @@ -169,8 +168,6 @@ int main(int argc, char ** argv) { llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_self_defrag (ctx); - llama_kv_self_update (ctx); n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; @@ -200,8 +197,6 @@ int main(int argc, char ** argv) { llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_self_defrag (ctx); - llama_kv_self_update (ctx); n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; } diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index e3d0c9542..754da1411 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke } } -static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { +static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) llama_kv_self_clear(ctx); // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); - if (llama_encode(ctx, batch) < 0) { - LOG_ERR("%s : failed to encode\n", __func__); + if (llama_decode(ctx, batch) < 0) { + LOG_ERR("%s : failed to process\n", __func__); } for (int i = 0; i < batch.n_tokens; i++) { @@ -233,7 +233,7 @@ int main(int argc, char ** argv) { // encode if at capacity if (batch.n_tokens + n_toks > n_batch) { float * out = emb + p * n_embd; - batch_encode(ctx, batch, out, s, n_embd); + batch_process(ctx, batch, out, s, n_embd); common_batch_clear(batch); p += s; s = 0; @@ -246,7 +246,7 @@ int main(int argc, char ** argv) { // final batch float * out = emb + p * n_embd; - batch_encode(ctx, batch, out, s, n_embd); + batch_process(ctx, batch, out, s, n_embd); // save embeddings to chunks for (int i = 0; i < n_chunks; i++) { @@ -267,7 +267,7 @@ int main(int argc, char ** argv) { batch_add_seq(query_batch, query_tokens, 0); std::vector query_emb(n_embd, 0); - batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd); + batch_process(ctx, query_batch, query_emb.data(), 1, n_embd); common_batch_clear(query_batch); diff --git a/examples/training/README.md b/examples/training/README.md index ecdf398f8..df4252792 100644 --- a/examples/training/README.md +++ b/examples/training/README.md @@ -10,8 +10,8 @@ Proof of concept: ``` sh export model_name=llama_3.2-1b && export quantization=f32 -./build/bin/finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512 -./build/bin/perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf +./build/bin/llama-finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512 +./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf ``` The perplexity value of 
the finetuned model should be lower after training on the test set for 2 epochs. diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 4746d5cb7..3d01184a2 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -129,6 +129,7 @@ option(GGML_LASX "ggml: enable lasx" ON) option(GGML_LSX "ggml: enable lsx" ON) option(GGML_RVV "ggml: enable rvv" ON) option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF) +option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) option(GGML_VXE "ggml: enable vxe" ON) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) @@ -176,7 +177,6 @@ option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF) option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug output" OFF) option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF) -option(GGML_VULKAN_PERF "ggml: enable Vulkan perf output" OFF) option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF) option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF) option(GGML_KOMPUTE "ggml: use Kompute" OFF) diff --git a/ggml/cmake/common.cmake b/ggml/cmake/common.cmake index 1976d0ae9..bb1ec9b37 100644 --- a/ggml/cmake/common.cmake +++ b/ggml/cmake/common.cmake @@ -24,3 +24,28 @@ function(ggml_get_flags CCID CCVER) set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE) set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE) endfunction() + +function(ggml_get_system_arch) + if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR + CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR + (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND + CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$")) + set(GGML_SYSTEM_ARCH "ARM" PARENT_SCOPE) + elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR + CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR + (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND + CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64|amd64)$")) + set(GGML_SYSTEM_ARCH "x86" PARENT_SCOPE) + elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR + "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ") + set(GGML_SYSTEM_ARCH "PowerPC" PARENT_SCOPE) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") + set(GGML_SYSTEM_ARCH "loongarch64" PARENT_SCOPE) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64") + set(GGML_SYSTEM_ARCH "riscv64" PARENT_SCOPE) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") + set(GGML_SYSTEM_ARCH "s390x" PARENT_SCOPE) + else() + set(GGML_SYSTEM_ARCH "UNKNOWN" PARENT_SCOPE) + endif() +endfunction() diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c81ff03fe..1a57f1cd7 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -528,15 +528,15 @@ extern "C" { GGML_UNARY_OP_STEP, GGML_UNARY_OP_TANH, GGML_UNARY_OP_ELU, + GGML_UNARY_OP_RELU, GGML_UNARY_OP_SIGMOID, GGML_UNARY_OP_GELU, - GGML_UNARY_OP_GELU_ERF, GGML_UNARY_OP_GELU_QUICK, GGML_UNARY_OP_SILU, GGML_UNARY_OP_HARDSWISH, GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_EXP, - GGML_UNARY_OP_RELU, + GGML_UNARY_OP_GELU_ERF, GGML_UNARY_OP_COUNT, }; @@ -935,6 +935,15 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // repeat a to the specified shape + GGML_API struct ggml_tensor * ggml_repeat_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + // sums repetitions in a into shape of b GGML_API struct ggml_tensor * ggml_repeat_back( struct ggml_context * 
ctx, @@ -2086,9 +2095,6 @@ extern "C" { GGML_API struct ggml_tensor * ggml_graph_get_grad (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); - GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname); - GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval); - // print info and performance information for the graph GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph); @@ -2172,6 +2178,7 @@ extern "C" { // scheduling priorities enum ggml_sched_priority { + GGML_SCHED_PRIO_LOW = -1, GGML_SCHED_PRIO_NORMAL, GGML_SCHED_PRIO_MEDIUM, GGML_SCHED_PRIO_HIGH, diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index ddea5ad38..76b24bd9d 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -109,6 +109,8 @@ if (MSVC) else () set(CMAKE_GENERATOR_PLATFORM_LWR "") endif () +ggml_get_system_arch() +message(STATUS "GGML_SYSTEM_ARCH: ${GGML_SYSTEM_ARCH}") if (NOT MSVC) if (GGML_STATIC) @@ -194,6 +196,7 @@ add_library(ggml-base ../include/ggml-opt.h ../include/gguf.h ggml.c + ggml.cpp ggml-alloc.c ggml-backend.cpp ggml-opt.cpp @@ -224,6 +227,7 @@ function(ggml_add_backend_library backend) set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL) add_dependencies(ggml ${backend}) + install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR}) else() add_library(${backend} ${ARGN}) target_link_libraries(ggml PUBLIC ${backend}) @@ -287,16 +291,20 @@ if (GGML_CPU_ALL_VARIANTS) if (NOT GGML_BACKEND_DL) message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL") endif() - ggml_add_cpu_backend_variant(x64) - ggml_add_cpu_backend_variant(sse42 SSE42) - ggml_add_cpu_backend_variant(sandybridge SSE42 AVX) - ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA) - ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512) - ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI) - ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI) - if (NOT MSVC) - # MSVC doesn't support AMX - ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) + if (GGML_SYSTEM_ARCH STREQUAL "x86") + ggml_add_cpu_backend_variant(x64) + ggml_add_cpu_backend_variant(sse42 SSE42) + ggml_add_cpu_backend_variant(sandybridge SSE42 AVX) + ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA) + ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512) + ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI) + ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI) + if (NOT MSVC) + # MSVC doesn't support AMX + ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) + endif() + else() + message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported on ${GGML_SYSTEM_ARCH}") endif() elseif (GGML_CPU) ggml_add_cpu_backend_variant_impl("") diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index b30b4cb38..b1050ad59 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ 
-1340,7 +1340,10 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { // allocate graph if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { // the re-allocation may cause the split inputs to be moved to a different address - ggml_backend_sched_synchronize(sched); + // synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy + for (int i = 0; i < sched->n_backends; i++) { + ggml_backend_synchronize(sched->backends[i]); + } #ifndef NDEBUG GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); #endif @@ -1564,7 +1567,6 @@ bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgra ggml_backend_sched_split_graph(sched, graph); - if (!ggml_backend_sched_alloc_splits(sched)) { return false; } @@ -1598,6 +1600,12 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { for (int i = 0; i < sched->n_backends; i++) { ggml_backend_synchronize(sched->backends[i]); } + if (!sched->is_alloc) { + // if the graph is not already allocated, always use copy 0 after a synchronization + // this ensures that during generation the same copy is used every time, + // which avoids changes in the graph that could cause CUDA or other graphs to be disabled + sched->cur_copy = 0; + } } void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) { diff --git a/ggml/src/ggml-blas/CMakeLists.txt b/ggml/src/ggml-blas/CMakeLists.txt index 0bf3c05d9..76064c3fd 100644 --- a/ggml/src/ggml-blas/CMakeLists.txt +++ b/ggml/src/ggml-blas/CMakeLists.txt @@ -81,7 +81,7 @@ if (BLAS_FOUND) target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES}) target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS}) else() - message(ERROR "BLAS not found, please refer to " - "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" - " to set correct GGML_BLAS_VENDOR") + message(FATAL_ERROR "BLAS not found, please refer to " + "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" + " to set correct GGML_BLAS_VENDOR") endif() diff --git a/ggml/src/ggml-cann/CMakeLists.txt b/ggml/src/ggml-cann/CMakeLists.txt old mode 100644 new mode 100755 index 0d8e483b2..7742b3915 --- a/ggml/src/ggml-cann/CMakeLists.txt +++ b/ggml/src/ggml-cann/CMakeLists.txt @@ -30,6 +30,7 @@ string(TOLOWER ${SOC_TYPE} SOC_VERSION) # SOC_VERSION need lower string(REGEX MATCH "[0-9]+[a-zA-Z]" SOC_TYPE_MAJOR_SN "${SOC_VERSION}") set(SOC_TYPE_COMPILE_OPTION "ASCEND_${SOC_TYPE_MAJOR_SN}") string(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION) +message(STATUS "CANN: SOC_VERSION = ${SOC_VERSION}") if (CANN_INSTALL_DIR) # Only Support Linux. 
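To make the `SOC_TYPE` handling in the hunk above concrete, here is a minimal configure/build sketch for the CANN backend. The SoC name is an assumed example and must match the target device; the CMake code above lowercases it into `SOC_VERSION` and derives the `ASCEND_*` compile definition from it.

```sh
# Sketch of a CANN build; Ascend910B3 is an assumed example SoC, adjust to your hardware.
cmake -B build -DGGML_CANN=on -DSOC_TYPE=Ascend910B3 -DCMAKE_BUILD_TYPE=Release
cmake --build build --config Release -j
```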
diff --git a/ggml/src/ggml-cann/Doxyfile b/ggml/src/ggml-cann/Doxyfile old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp old mode 100644 new mode 100755 index f5462c5a1..f311864d4 --- a/ggml/src/ggml-cann/acl_tensor.cpp +++ b/ggml/src/ggml-cann/acl_tensor.cpp @@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) { return ACL_FLOAT; case GGML_TYPE_F16: return ACL_FLOAT16; + case GGML_TYPE_BF16: + return ACL_BF16; case GGML_TYPE_I8: return ACL_INT8; case GGML_TYPE_I16: diff --git a/ggml/src/ggml-cann/acl_tensor.h b/ggml/src/ggml-cann/acl_tensor.h old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp old mode 100644 new mode 100755 index cbf9783b7..437ece2d4 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -66,6 +66,7 @@ #include #include #include +#include #include #include @@ -74,11 +75,13 @@ #include #include "ggml-impl.h" +#include "ggml.h" #define GGML_COMMON_DECL_C #include "../ggml-common.h" + void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0, aclTensor ** acl_src1, aclTensor ** acl_dst) { GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0)); @@ -2697,14 +2700,10 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* } } - // GroupedMatmulV2 required tensor_list.size < 128 size_t GROUP_SIZE = 128; - std::vector> src0_tensor_vec_vec; - std::vector> src1_tensor_vec_vec; - std::vector> dst_tensor_vec_vec; - - // split and call GroupedMatmulV2 + // GroupedMatmulV2 required tensor_list.size < 128 for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) { + // split and call GroupedMatmulV2 size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size()); std::vector src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end); std::vector src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end); @@ -2722,6 +2721,133 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* return; } +/** + * @brief Performs expert-specific matrix multiplication (MoE) with + * quantized precision using the CANN backend. + * + * This function executes a matrix multiplication operation tailored for + * Mixture of Experts (MoE) models, where the input tensor is multiplied + * with expert-specific quantized weight matrices. It leverages the CANN + * backend to perform efficient low-precision computations and stores the + * quantized result in the destination tensor `dst`. + * + * Quantization techniques reduce memory footprint and improve performance + * by using lower-bit representations (e.g., int8) instead of floating-point. + * This function is designed to work with such formats and may incorporate + * optimizations like identity-based fast paths or routing masks for sparse + * expert selection. + * + * @param ctx The context for executing CANN backend operations. + * @param dst The destination tensor where the quantized MoE multiplication result + * will be stored. + * + * @note This function assumes quantized data types and is designed for + * MoE architectures with potential sparse expert routing. 
+ */ +static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + // TODO: Use aclnnGroupedMatMul + //dst [M, K, N, 1] + ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] + ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 + ggml_tensor * ids = dst->src[2]; //ids [K, N] + + GGML_TENSOR_BINARY_OP_LOCALS + + // copy index from npu to cpu + int64_t n_as = ne02; // A + int64_t n_ids = ids->ne[0]; // K + + std::vector ids_host(ggml_nbytes(ids)); + ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids), + ACL_MEMCPY_DEVICE_TO_HOST); + ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); + + char * src0_original = (char *) src0->data; + char * src1_original = (char *) src1->data; + char * dst_original = (char *) dst->data; + + ggml_tensor src0_row = *src0; + ggml_tensor src1_row = *src1; + ggml_tensor dst_row = *dst; + + const enum ggml_type type = dst->src[0]->type; + float weight_elem_size; + if (type == GGML_TYPE_Q4_0) { + weight_elem_size = float(sizeof(uint8_t)) / 2; + } else if (type == GGML_TYPE_Q8_0) { + weight_elem_size = float(sizeof(uint8_t)); + } else { + GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 "); + } + + // src0_row [D, M, 1, 1] weight without permute + src0_row.ne[2] = 1; + src0_row.ne[3] = 1; + src0_row.nb[0] = weight_elem_size; + src0_row.nb[1] = weight_elem_size * ne00; + src0_row.nb[2] = weight_elem_size * ne00; + src0_row.nb[3] = weight_elem_size * ne00; + size_t weight_stride = ne00 * ne01 * weight_elem_size; + size_t weight_size = weight_stride * ne02 * ne03; + + // scale [D, M, 1, 1] -> scale && permute + size_t scale_elem_size = sizeof(uint16_t); + size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size; + + // src1_row [D, 1, 1, 1] -> input + src1_row.ne[1] = 1; + src1_row.ne[2] = 1; + src1_row.ne[3] = 1; + src1_row.nb[2] = nb11; + src1_row.nb[3] = nb11; + + // dst_row [M, 1, 1, 1] -> out + dst_row.ne[1] = 1; + dst_row.ne[2] = 1; + dst_row.ne[3] = 1; + dst_row.nb[2] = nb1; + dst_row.nb[3] = nb1; + + //create weight for one row + ggml_cann_pool_alloc weight_allocator(ctx.pool()); + void* weight_buffer = weight_allocator.alloc(nb02); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + // expert index + int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + // If B = 1 (broadcast), always use 0; otherwise, use id. + int64_t i11 = (ne11 == 1 ? 
0 : id); + int64_t i12 = iid1; + + int64_t i1 = id; + int64_t i2 = i12; + + void* src0_tmp_ptr = src0_original + i02*weight_stride; + void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride; + void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12; + void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2; + + // mem cpy + ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride, + ACL_MEMCPY_DEVICE_TO_DEVICE); + void* scale_buffer = (char*)weight_buffer + weight_stride; + ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride, + ACL_MEMCPY_DEVICE_TO_DEVICE); + + src0_row.data = weight_buffer; + src1_row.data = src1_tmp_ptr; + dst_row.data = dst_tmp_ptr; + dst_row.src[0] = &src0_row; + dst_row.src[1] = &src1_row; + + ggml_cann_mul_mat(ctx, &dst_row); + } + } + return; +} + void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { const enum ggml_type type = dst->src[0]->type; switch (type) { @@ -2729,8 +2855,339 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { case GGML_TYPE_F16: ggml_cann_mul_mat_id_fp(ctx, dst); break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + ggml_cann_mul_mat_id_quant(ctx, dst); + break; default: GGML_ABORT("Unsupported type for mul_mat_id"); break; } } + +void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + + ggml_tensor* src0 = dst->src[0]; // q, fp32 + ggml_tensor* src1 = dst->src[1]; // k, fp16 + ggml_tensor* src2 = dst->src[2]; // v, fp16 + ggml_tensor* src3 = dst->src[3]; // mask, fp16 + + float maxBias = 0.0f; + float scaleValue = 1.0f; + float logitSoftcap = 0.0f; + memcpy(&scaleValue, (float*)dst->op_params + 0, sizeof(float)); + memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float)); + memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float)); + + if(logitSoftcap == 0.0f){ + size_t faElemSize = sizeof(uint16_t); + auto faDataType = ACL_FLOAT16; //ACL_BF16; + + aclTensor* acl_src0_f16_tensor = nullptr; + aclTensor* acl_src1_f16_tensor = nullptr; + aclTensor* acl_src2_f16_tensor = nullptr; + aclTensor* acl_dst_f16_tensor = nullptr; + + // Step 1: cast the src0 (Query) to fp16 if needed + ggml_cann_pool_alloc src0_f16_allocator(ctx.pool()); + void* src0_f16_buffer = nullptr; + + if(ggml_cann_type_mapping(src0->type) != faDataType){ + aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0); + src0_f16_buffer = src0_f16_allocator.alloc( + ggml_nelements(src0) * faElemSize); + + int64_t* src0_f16_ne = src0->ne; + size_t src0_f16_nb[GGML_MAX_DIMS]; + src0_f16_nb[0] = sizeof(uint16_t); + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1]; + } + + acl_src0_f16_tensor = ggml_cann_create_tensor( + src0_f16_buffer, faDataType, faElemSize, + src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS + ); + aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType); + ggml_cann_release_resources(ctx, acl_src0_f32_tensor); + }else{ + acl_src0_f16_tensor = ggml_cann_create_tensor(src0); + } + + // Step 2: create the acl tensors for src1 (Key), src2 (Value), + // and the direct output from FusedInferAttention + + acl_src1_f16_tensor = ggml_cann_create_tensor(src1); + acl_src2_f16_tensor = ggml_cann_create_tensor(src2); + + ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); + void* out_f16_buffer = out_f16_allocator.alloc( + ggml_nelements(dst) * faElemSize); + + int64_t* out_f16_ne = src0->ne; + size_t out_f16_nb[GGML_MAX_DIMS]; + out_f16_nb[0] = faElemSize; + for(int i = 1; i < 
GGML_MAX_DIMS; ++i){ + out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; + } + + acl_dst_f16_tensor = ggml_cann_create_tensor( + out_f16_buffer, faDataType, faElemSize, + out_f16_ne, out_f16_nb, GGML_MAX_DIMS + ); + + // Step 3: create the PSEShift tensor if needed + // this tensor is considered as mask (f16) in the llama.cpp + + aclTensor* bcast_pse_tensor = nullptr; + int64_t bcast_pse_ne[GGML_MAX_DIMS]; + size_t bcast_pse_nb[GGML_MAX_DIMS]; + ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); + void* bcast_pse_buffer = nullptr; + + if(src3 != nullptr){ + bcast_pse_buffer = bcast_pse_allocator.alloc( + ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)); + + if(src0->ne[1] > 1){ + // Case 1: broadcast pse for prefill stage with multiple head + aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3); + bcast_pse_ne[0] = src3->ne[0]; + bcast_pse_ne[1] = src3->ne[1]; + bcast_pse_ne[2] = src0->ne[2]; + bcast_pse_ne[3] = src3->ne[3]; + + bcast_pse_nb[0] = sizeof(uint16_t); + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; + } + + bcast_pse_tensor = ggml_cann_create_tensor( + bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), + bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); + + int64_t repeats[] = {1, src0->ne[2], 1, 1}; + aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats); + + ggml_cann_release_resources(ctx, acl_mask_f16_tensor); + }else{ + // Case 2: trunc the first row and broadcast pse for decode stage with multiple head + int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]}; + size_t* trunc_pse_nb = src3->nb; + + aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( + src3->data, ACL_FLOAT16, sizeof(uint16_t), + trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS); + + bcast_pse_ne[0] = src3->ne[0]; + bcast_pse_ne[1] = src0->ne[1]; + bcast_pse_ne[2] = src0->ne[2]; + bcast_pse_ne[3] = src3->ne[3]; + + bcast_pse_nb[0] = sizeof(uint16_t); + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; + } + + bcast_pse_tensor = ggml_cann_create_tensor( + bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), + bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); + + int64_t repeats[] = {1, src0->ne[2], 1, 1}; + aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats); + + ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor); + } + + // Compute the slope if needed. Derived from ggml_cann_softmax(). 
+ if(maxBias != 0.0f){ + // alibi + const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3]; + const int64_t n_head = src0->ne[2]; + const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); + float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor); + float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor); + // init arange + ggml_cann_pool_alloc arange_allocator(ctx.pool(), + ne2_ne3 * faElemSize); + void* tmp_arange_buffer = arange_allocator.get(); + + // arange1: [1, ..., n_heads_log2_floor+1) + float start = 1; + float stop = n_heads_log2_floor + 1; + float step = 1; + int64_t n_elements_arange = n_heads_log2_floor; + + int64_t tmp_arange1_ne[] = {n_heads_log2_floor}; + size_t tmp_arange1_nb[] = {faElemSize}; + aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor( + tmp_arange_buffer, faDataType, faElemSize, + tmp_arange1_ne, tmp_arange1_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + + aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange); + + aclTensor* tmp_arange2_tensor = nullptr; + if (n_heads_log2_floor < ne2_ne3) { + // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1) + start = 1; + stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1; + step = 2; + n_elements_arange = ne2_ne3 - n_heads_log2_floor; + int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor}; + size_t tmp_arange2_nb[] = {faElemSize}; + + aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor( + (char*)tmp_arange_buffer + + n_heads_log2_floor * faElemSize, + faDataType, faElemSize, + tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step, + n_elements_arange); + } + + // init mk_base + ggml_cann_pool_alloc mk_base_allocator(ctx.pool(), + ne2_ne3 * faElemSize); + void* tmp_mk_base_buffer = mk_base_allocator.get(); + int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor}; + size_t tmp_mk_base1_nb[] = {faElemSize}; + aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_base1_ne, tmp_mk_base1_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + + aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor); + + aclTensor* tmp_mk_base2_tensor = nullptr; + if (n_heads_log2_floor < ne2_ne3) { + int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor}; + size_t tmp_mk_base2_nb[] = {faElemSize}; + aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor( + (char*)tmp_mk_base_buffer + + n_heads_log2_floor * faElemSize, + faDataType, faElemSize, + tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor); + } + + // init mk + int64_t tmp_mk_base_ne[] = {ne2_ne3}; + size_t tmp_mk_base_nb[] = {faElemSize}; + aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_base_ne, tmp_mk_base_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclTensor* tmp_arange_tensor = ggml_cann_create_tensor( + tmp_arange_buffer, faDataType, faElemSize, + tmp_mk_base_ne, tmp_mk_base_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor); + + // reshape mk + int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]}; + size_t tmp_mk_nb[GGML_MAX_DIMS]; + tmp_mk_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1]; + } + aclTensor* tmp_mk_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, + ACL_FORMAT_ND); + GGML_CANN_CALL_ACLNN_OP(ctx, 
InplaceMul, bcast_pse_tensor, tmp_mk_tensor); + + ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, + tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, + tmp_arange_tensor, tmp_mk_tensor); + } + } + + // Step 4: set the inputs for FusedInferAttention. + int kvTensorNum = 1; + aclTensor* acl_q_tensor = acl_src0_f16_tensor; + aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor}; + aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor}; + auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum); + auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum); + + int64_t numHeads = src0->ne[2]; // N + int64_t numKeyValueHeads = src1->ne[2]; + // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d) + int64_t preTokens = 65535; + int64_t nextTokens = 65535; + char layout[5] = {'B', 'N', 'S', 'D', 0}; + int64_t sparseMode = 0; + int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2; + int64_t blockSize = 0; + int64_t antiquantMode = 0; + bool softmaxLseFlag = false; + int64_t keyAntiquantMode = 0; + int64_t valueAntiquantMode = 0; + + // Step 5: launch the FusedInferAttentionScoreV2 kernel. + // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md + + GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, + acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v + bcast_pse_tensor, nullptr, // pse, mask + nullptr, nullptr, // actSeqLen, actSeqLenkv + nullptr, nullptr, // deqScale1, quantScale1 + nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2 + nullptr, nullptr, // antiquantScale, antiquantOffset + nullptr, // blockTable + nullptr, nullptr, // qPadSize, kvPadSize + nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset + nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset + nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen + numHeads, scaleValue, // heads, scaleValue + preTokens, nextTokens, // preTokens, nextTokens + layout, // inputLayout + numKeyValueHeads, // numKVHeads + sparseMode, innerPrecise, // sparseMode, innerPrecise + blockSize, antiquantMode, // blockSize, antiquantMode + softmaxLseFlag, // softmaxLseFlag + keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode + acl_dst_f16_tensor, // attentionOut + nullptr // softmaxLse + ); + + // Step 6: post-processing, permute and cast to f32 + + int64_t new_dim[] = {0, 2, 1, 3}; + aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); + + if(ggml_cann_type_mapping(dst->type) != faDataType){ + ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool()); + perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); + void* perm_out_f16_buffer = perm_out_f16_allocator.get(); + + int64_t* perm_out_f16_ne = dst->ne; + size_t perm_out_f16_nb[GGML_MAX_DIMS]; + perm_out_f16_nb[0] = faElemSize; + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1]; + } + aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( + perm_out_f16_buffer, faDataType, faElemSize, + perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); + aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS); + aclnn_cast(ctx, + acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); + ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor); + }else{ + // only need to permute + aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS); + } + ggml_cann_release_resources(ctx, 
acl_src0_f16_tensor, + acl_src1_f16_tensor, + acl_src2_f16_tensor, + acl_dst_f16_tensor, + acl_dst_tensor); + if(src3 != nullptr){ + ggml_cann_release_resources(ctx, bcast_pse_tensor); + } + }else{ + GGML_ABORT("Function is not implemented."); + } +} diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h old mode 100644 new mode 100755 index 15993cce6..80ce80bae --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -714,6 +714,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst); */ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst); +/** + * @brief Performs the Flash Attention extended operator using the CANN backend. + * + * @details This function implements the memory-efficient Flash Attention algorithm + * for computing scaled dot-product attention with hardware acceleration. + * The result is stored in the destination tensor `dst`. + * + * This operation is accelerated using the CANN backend to improve runtime performance. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the result will be stored. + * dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`. + */ +void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst); + /* * @brief A generic wrapper for ACL resources with custom deleter support. */ diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp old mode 100644 new mode 100755 index 0cb7bbf17..c0ea26002 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -36,6 +36,7 @@ #include "ggml-backend-impl.h" #include "ggml-cann/aclnn_ops.h" #include "ggml-cann/common.h" +#include "ggml.h" #define GGML_COMMON_DECL_C @@ -1748,6 +1749,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_OP_COUNT_EQUAL: ggml_cann_count_equal(ctx, dst); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_cann_flash_attn_ext(ctx, dst); + break; default: return false; } @@ -2035,6 +2039,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_TYPE_F16: case GGML_TYPE_F32: return true; + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: +#ifdef ASCEND_310P + // Q4 && Q8 per group is not suppor on 310p device + return false; +#endif + // only support contiguous for quantized types. 
+ return ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]); default: return false; } @@ -2168,6 +2181,38 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_FLASH_ATTN_EXT:{ + // derived from [ggml-cuda.cu] + if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){ + return false; + } + if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){ + return false; + } + if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){ + return false; + } + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + // different head sizes of K and V are not supported yet + return false; + } + if (op->src[0]->ne[0] == 192) { + return false; + } + if (op->src[0]->ne[0] == 576) { + // DeepSeek MLA + return false; + } + if (op->src[0]->ne[3] != 1) { + return false; + } + float logitSoftcap = 0.0f; + memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float)); + if(logitSoftcap != 0.0f) { + return false; + } + return true; + } default: return false; } diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 1d4259dae..b3237eead 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -82,13 +82,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name) target_link_libraries(${GGML_CPU_NAME} PUBLIC memkind) endif() - if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR - CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR - (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND - CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$")) - + if (GGML_SYSTEM_ARCH STREQUAL "ARM") message(STATUS "ARM detected") - if (MSVC AND NOT CMAKE_C_COMPILER_ID STREQUAL "Clang") message(FATAL_ERROR "MSVC is not supported for ARM, use clang") else() @@ -170,12 +165,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endforeach() endif() endif() - elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR - (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND - CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64|amd64)$")) - + elseif (GGML_SYSTEM_ARCH STREQUAL "x86") message(STATUS "x86 detected") - if (MSVC) # instruction set detection for MSVC only if (GGML_NATIVE) @@ -299,7 +290,26 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() endif() endif() - elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ") + + if (GGML_BACKEND_DL) + if (GGML_NATIVE) + # the feature check relies on ARCH_DEFINITIONS, but it is not set with GGML_NATIVE + message(FATAL_ERROR "GGML_NATIVE is not compatible with GGML_BACKEND_DL, consider using GGML_CPU_ALL_VARIANTS") + endif() + + # The feature detection code is compiled as a separate target so that + # it can be built without the architecture flags + # Since multiple variants of the CPU backend may be included in the same + # build, using set_source_files_properties() to set the arch flags is not possible + set(GGML_CPU_FEATS_NAME ${GGML_CPU_NAME}-feats) + add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/cpu-feats-x86.cpp) + target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . .. 
../include) + target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARCH_DEFINITIONS}) + target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED) + set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_link_libraries(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_FEATS_NAME}) + endif() + elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC") message(STATUS "PowerPC detected") if (GGML_NATIVE) if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") @@ -325,9 +335,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE}) endif() endif() - elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") + elseif (GGML_SYSTEM_ARCH STREQUAL "loongarch64") message(STATUS "loongarch64 detected") - list(APPEND ARCH_FLAGS -march=loongarch64) if (GGML_LASX) list(APPEND ARCH_FLAGS -mlasx) @@ -335,16 +344,18 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_LSX) list(APPEND ARCH_FLAGS -mlsx) endif() - elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64") - message(STATUS "RISC-V detected") + elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64") + message(STATUS "riscv64 detected") if (GGML_RVV) - if (GGML_RV_ZFH) - list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -DGGML_RV_ZFH -mabi=lp64d) + if (GGML_XTHEADVECTOR) + list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d) + elseif (GGML_RV_ZFH) + list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d) else() list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) endif() endif() - elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") + elseif (GGML_SYSTEM_ARCH STREQUAL "s390x") message(STATUS "s390x detected") file(READ "/proc/cpuinfo" CPUINFO_CONTENTS) string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS}) @@ -477,25 +488,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS}) target_compile_definitions(${GGML_CPU_NAME} PRIVATE ${ARCH_DEFINITIONS}) - if (GGML_BACKEND_DL) - if (GGML_NATIVE) - # the feature check relies on ARCH_DEFINITIONS, but it is not set with GGML_NATIVE - message(FATAL_ERROR "GGML_NATIVE is not compatible with GGML_BACKEND_DL, consider using GGML_CPU_ALL_VARIANTS") - endif() - - # The feature detection code is compiled as a separate target so that - # it can be built without the architecture flags - # Since multiple variants of the CPU backend may be included in the same - # build, using set_source_files_properties() to set the arch flags is not possible - set(GGML_CPU_FEATS_NAME ${GGML_CPU_NAME}-feats) - add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/cpu-feats-x86.cpp) - target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . .. 
../include) - target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARCH_DEFINITIONS}) - target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED) - set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) - target_link_libraries(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_FEATS_NAME}) - endif() - if (EMSCRIPTEN) set_target_properties(${GGML_CPU_NAME} PROPERTIES COMPILE_FLAGS "-msimd128") endif() diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp index 8ff6d64a4..0a3ff867c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp @@ -1191,7 +1191,7 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c } } return; -#elif defined(__riscv_v_intrinsic) +#elif defined __riscv_v if (__riscv_vlenb() >= QK4_0) { const size_t vl = QK4_0; @@ -3783,7 +3783,7 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c } return; } -#elif defined(__riscv_v_intrinsic) +#elif defined __riscv_v if (__riscv_vlenb() >= QK4_0) { const size_t vl = QK4_0; diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index e4af07635..b3f1b5ca7 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -320,21 +320,17 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #ifdef __wasm_simd128__ #include -#else +#endif + #ifdef __POWER9_VECTOR__ #include -#else +#endif + #if defined(_MSC_VER) || defined(__MINGW32__) #include -#else -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__) -#if !defined(__riscv) +#elif defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__) #include #endif -#endif -#endif -#endif -#endif #ifdef __riscv_v_intrinsic #include diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index a89ce9bb1..40bded476 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -883,7 +883,7 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); #endif } -#elif defined(__riscv_v_intrinsic) +#elif defined(__riscv_v) size_t vl = QK8_0; @@ -1221,7 +1221,7 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); #endif } -#elif defined(__riscv_v_intrinsic) +#elif defined(__riscv_v) size_t vl = QK8_1; @@ -2384,7 +2384,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi } sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); -#elif defined(__riscv_v_intrinsic) +#elif defined(__riscv_v) size_t vl = qk / 2; for (; ib < nb; ++ib) { @@ -2774,7 +2774,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi } sumf = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) +#elif defined(__riscv_v) size_t vl = qk / 2; for (; ib < nb; ++ib) { @@ -3121,7 +3121,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi } sumf = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) +#elif defined(__riscv_v) size_t vl; size_t vlenb = __riscv_vlenb(); @@ -3460,7 +3460,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi } sumf = 
hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) +#elif defined(__riscv_v) size_t vl; size_t vlenb = __riscv_vlenb(); @@ -3897,7 +3897,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi } sumf = hsum_float_8(accum); -#elif defined(__riscv_v_intrinsic) +#elif defined(__riscv_v) size_t vl = qk; for (; ib < nb; ++ib) { @@ -5100,14 +5100,111 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi *s = sumf; -#elif defined __riscv_v_intrinsic +#elif defined __riscv_xtheadvector + + float sumf = 0; + uint8_t atmp[16]; + + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + uint8_t *patmp = atmp; + int vsums; + int tmp; + __asm__ __volatile__( + "th.vsetvli zero, %[vl16], e8, m1\n\t" + "th.vmv.v.x v8, zero\n\t" + "th.vlb.v v1, (%[sc])\n\t" + "th.vand.vi v0, v1, 0xF\n\t" + "th.vsrl.vi v1, v1, 4\n\t" + "th.vsb.v v0, (%[scale])\n\t" + "th.vwaddu.vx v16, v1, zero\n\t" + "th.vsetvli zero, %[vl16], e16, m2\n\t" + "th.vlh.v v2, (%[bsums])\n\t" + "th.vwmul.vv v4, v16, v2\n\t" + "th.vsetvli zero, %[vl16], e32, m4\n\t" + "th.vredsum.vs v8, v4, v8\n\t" + "th.vmv.x.s %[vsums], v8" + : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) + : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) + , [vl16] "r" (16) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf += dmin * vsums; + int isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "th.vsetvli zero, %[vl32], e8, m2\n\t" + "th.vlb.v v0, (%[q2])\n\t" + "th.vsrl.vi v2, v0, 2\n\t" + "th.vsrl.vi v4, v0, 4\n\t" + "th.vsrl.vi v6, v0, 6\n\t" + "th.vand.vi v0, v0, 0x3\n\t" + "th.vand.vi v2, v2, 0x3\n\t" + "th.vand.vi v4, v4, 0x3\n\t" + "th.vsetvli zero, %[vl128], e8, m8\n\t" + "th.vlb.v v8, (%[q8])\n\t" + "th.vsetvli zero, %[vl64], e8, m4\n\t" + "th.vwmul.vv v16, v0, v8\n\t" + "th.vwmul.vv v24, v4, v12\n\t" + "th.vsetvli zero, %[vl16], e16, m2\n\t" + "th.vmv.v.x v0, zero\n\t" + "th.vwredsum.vs v10, v16, v0\n\t" + "th.vwredsum.vs v9, v18, v0\n\t" + "th.vwredsum.vs v8, v20, v0\n\t" + "th.vwredsum.vs v7, v22, v0\n\t" + "th.vwredsum.vs v11, v24, v0\n\t" + "th.vwredsum.vs v12, v26, v0\n\t" + "th.vwredsum.vs v13, v28, v0\n\t" + "th.vwredsum.vs v14, v30, v0\n\t" + "li %[tmp], 4\n\t" + "th.vsetvli zero, %[tmp], e32, m1\n\t" + "th.vslideup.vi v10, v9, 1\n\t" + "th.vslideup.vi v8, v7, 1\n\t" + "th.vslideup.vi v11, v12, 1\n\t" + "th.vslideup.vi v13, v14, 1\n\t" + "th.vslideup.vi v10, v8, 2\n\t" + "th.vslideup.vi v11, v13, 2\n\t" + "li %[tmp], 8\n\t" + "th.vsetvli zero, %[tmp], e32, m2\n\t" + "th.vlbu.v v12, (%[scale])\n\t" + "th.vmul.vv v10, v10, v12\n\t" + "th.vredsum.vs v0, v10, v0\n\t" + "th.vmv.x.s %[tmp], v0\n\t" + "add %[isum], %[isum], %[tmp]" + : [tmp] "=&r" (tmp), [isum] "+&r" (isum) + : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) + , [vl16] "r" (16), [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q2 += 32; q8 += 128; patmp += 
8; + } + + sumf += dall * isum; + } + + *s = sumf; + +#elif defined __riscv_v + + float sumf = 0; + uint8_t atmp[16]; const int vector_length = __riscv_vlenb() * 8; - float sumf = 0; - uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; - uint8_t atmp[16]; switch (vector_length) { case 256: @@ -6137,14 +6234,141 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi *s = sumf; -#elif defined __riscv_v_intrinsic +#elif defined __riscv_xtheadvector - uint32_t aux[3]; uint32_t utmp[4]; - - const int vector_length = __riscv_vlenb() * 8; float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + int8_t * scale = (int8_t *)utmp; + int tmp; + __asm__ __volatile__( + "li %[tmp], 12\n\t" + "th.vsetvli zero, %[tmp], e8, m1\n\t" + "th.vlb.v v0, (%[s6b])\n\t" + "th.vmv.v.v v2, v0\n\t" + "li %[tmp], 2\n\t" + "th.vsetvli zero, %[tmp], e64, m1\n\t" + "th.vmv.v.x v9, %[sh]\n\t"\ + "th.vslidedown.vi v1, v0, 1\n\t" + "th.vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} + "th.vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} + "li %[tmp], 4\n\t" + "th.vsetvli zero, %[tmp], e32, m1\n\t" + "th.vid.v v9\n\t" + "th.vmv.x.s %[tmp], v1\n\t" + "th.vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} + "th.vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} + "th.vsrl.vv v4, v1, v9\n\t" + "th.vsrl.vv v2, v0, v8\n\t" + "th.vand.vx v5, v4, %[kmask1]\n\t" + "th.vand.vx v3, v2, %[kmask2]\n\t" + "th.vsll.vi v6, v5, 4\n\t" + "th.vor.vv v7, v6, v3\n\t" + "li %[tmp], 16\n\t" + "th.vsetvli zero, %[tmp], e8, m1\n\t" + "th.vsub.vx v0, v7, %[c]\n\t" + "th.vsb.v v0, (%[scale])" + : [tmp] "=&r" (tmp) + : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) + , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + + uint8_t m = 1; + int isum = 0; + for (int j = 0; j < QK_K; j += 128) { + __asm__ __volatile__( + // fixme: use v0p7 mask layout directly + "th.vsetvli zero, %[vl32], e8, m2\n\t" + "th.vlb.v v8, (%[q3])\n\t" + "th.vsrl.vi v10, v8, 2\n\t" + "th.vsrl.vi v12, v8, 4\n\t" + "th.vsrl.vi v14, v8, 6\n\t" + "th.vand.vi v8, v8, 3\n\t" + "th.vand.vi v10, v10, 3\n\t" + "th.vand.vi v12, v12, 3\n\t" + "th.vlb.v v2, (%[qh])\n\t" + "th.vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "th.vmseq.vx v0, v4, zero\n\t" + "th.vadd.vi v8, v8, -4, v0.t\n\t" + "th.vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "th.vmseq.vx v0, v4, zero\n\t" + "th.vadd.vi v10, v10, -4, v0.t\n\t" + "th.vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "th.vmseq.vx v0, v4, zero\n\t" + "th.vadd.vi v12, v12, -4, v0.t\n\t" + "th.vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "th.vmseq.vx v0, v4, zero\n\t" + "th.vadd.vi v14, v14, -4, v0.t\n\t" + "th.vsetvli zero, %[vl128], e8, m8\n\t" + "th.vlb.v v0, (%[q8])\n\t" + "th.vsetvli zero, %[vl64], e8, m4\n\t" + "th.vwmul.vv v16, v0, v8\n\t" + "th.vwmul.vv v24, v4, v12\n\t" + "li %[tmp], 16\n\t" + "th.vsetvli zero, %[tmp], e16, m2\n\t" + "th.vmv.v.x v0, zero\n\t" + "th.vwredsum.vs v10, v16, v0\n\t" + "th.vwredsum.vs v9, v18, v0\n\t" + "th.vwredsum.vs v8, v20, v0\n\t" + "th.vwredsum.vs v7, v22, v0\n\t" + 
"th.vwredsum.vs v11, v24, v0\n\t" + "th.vwredsum.vs v12, v26, v0\n\t" + "th.vwredsum.vs v13, v28, v0\n\t" + "th.vwredsum.vs v14, v30, v0\n\t" + "li %[tmp], 4\n\t" + "th.vsetvli zero, %[tmp], e32, m1\n\t" + "th.vslideup.vi v10, v9, 1\n\t" + "th.vslideup.vi v8, v7, 1\n\t" + "th.vslideup.vi v11, v12, 1\n\t" + "th.vslideup.vi v13, v14, 1\n\t" + "th.vslideup.vi v10, v8, 2\n\t" + "th.vslideup.vi v11, v13, 2\n\t" + "li %[tmp], 8\n\t" + "th.vsetvli zero, %[tmp], e32, m2\n\t" + "th.vlb.v v12, (%[scale])\n\t" + "th.vmul.vv v10, v10, v12\n\t" + "th.vredsum.vs v0, v10, v0\n\t" + "th.vmv.x.s %[tmp], v0\n\t" + "add %[isum], %[isum], %[tmp]" + : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum) + : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) + , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q3 += 32; q8 += 128; scale += 8; + } + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + sumf += d * isum; + } + + *s = sumf; + +#elif defined __riscv_v + + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; + const int vector_length = __riscv_vlenb() * 8; + switch (vector_length) { case 256: for (int i = 0; i < nb; ++i) { @@ -6331,7 +6555,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi "vslideup.vi v13, v14, 1\n\t" "vslideup.vi v10, v8, 2\n\t" "vslideup.vi v11, v13, 2\n\t" - "vsetivli zero, 8, e32, m2\n\t"\ + "vsetivli zero, 8, e32, m2\n\t" "vle8.v v15, (%[scale])\n\t" "vsext.vf4 v12, v15\n\t" "vmul.vv v10, v10, v12\n\t" @@ -6771,7 +6995,11 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); +#ifdef __ARM_FEATURE_MATMUL_INT8 + assert((nrc == 2) || (nrc == 1)); +#else assert(nrc == 1); +#endif UNUSED(nrc); UNUSED(bx); UNUSED(by); @@ -6788,6 +7016,146 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi uint32_t utmp[4]; +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q4_K * GGML_RESTRICT x0 = x; + const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx); + const block_q8_K * GGML_RESTRICT y0 = y; + const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by); + + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + float32x4_t vfsum = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) { + const uint8_t * GGML_RESTRICT qx0 = x0->qs; + const uint8_t * GGML_RESTRICT qx1 = x1->qs; + const int8_t * GGML_RESTRICT qy0 = y0->qs; + const int8_t * GGML_RESTRICT qy1 = y1->qs; + + // decode scales and mins + int8_t x0_scales[8], x1_scales[8]; + int16x8_t x0_mins, x1_mins; + { + uint32_t scales_mins[3]; + memcpy(scales_mins, x0->scales, 12); + const uint32_t mins_0_3 = scales_mins[1] & kmask1; + const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4); + const uint32x2_t mins = {mins_0_3, mins_4_7}; + x0_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins))); + uint32_t scales[2]; + scales[0] = scales_mins[0] & kmask1; // scales 0~3 + scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // 
scales 4~7 + memcpy(x0_scales, scales, 8); + } + { + uint32_t scales_mins[3]; + memcpy(scales_mins, x1->scales, 12); + const uint32_t mins_0_3 = scales_mins[1] & kmask1; + const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4); + const uint32x2_t mins = {mins_0_3, mins_4_7}; + x1_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins))); + uint32_t scales[2]; + scales[0] = scales_mins[0] & kmask1; // scales 0~3 + scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7 + memcpy(x1_scales, scales, 8); + } + + int32x4_t visum = {0}; + + // process 64 data points per iteration, totally 256 data points + for (int j = 0; j < QK_K / 64; ++j, qx0 += 32, qx1 += 32, qy0 += 64, qy1 += 64) { + const int8x16x4_t vy0 = vld1q_s8_x4(qy0); + const int8x16x4_t vy1 = vld1q_s8_x4(qy1); + + int8x16_t vx0[4], vx1[4]; + { + const uint8x16x2_t vv = vld1q_u8_x2(qx0); + vx0[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b)); + vx0[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b)); + vx0[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4)); + vx0[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4)); + } + { + const uint8x16x2_t vv = vld1q_u8_x2(qx1); + vx1[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b)); + vx1[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b)); + vx1[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4)); + vx1[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4)); + } + + // process 32 data points (share same block scale) per iteration + for (int k = 0; k < 2; ++k) { + const int blk = j * 2 + k; + const int32x4_t block_scale = { + x0_scales[blk], + x0_scales[blk], + x1_scales[blk], + x1_scales[blk], + }; + + int32x4_t vr = {0}; + for (int l = 0; l < 2; ++l) { + const int idx = k * 2 + l; + const int64x2_t vx0_s64 = vreinterpretq_s64_s8(vx0[idx]); + const int64x2_t vx1_s64 = vreinterpretq_s64_s8(vx1[idx]); + const int64x2_t vy0_s64 = vreinterpretq_s64_s8(vy0.val[idx]); + const int64x2_t vy1_s64 = vreinterpretq_s64_s8(vy1.val[idx]); + const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vx0_s64, vx1_s64)); + const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vx0_s64, vx1_s64)); + const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vy0_s64, vy1_s64)); + const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vy0_s64, vy1_s64)); + vr = vmmlaq_s32(vr, vx_l, vy_l); + vr = vmmlaq_s32(vr, vx_h, vy_h); + } + // apply block scale, will NOT overflow + // block_scale * sum_256(int4*int8) <= 2^(8+8+4+8) = 28 bits + visum = vmlaq_s32(visum, vr, block_scale); + } + } + + // adjust bias, apply superblock scale + { + int32_t bias[4]; + // no obvious uplift from sve sdot-16, just use neon mul add + const int16x8_t y0_sums = vpaddq_s16(vld1q_s16(y0->bsums), vld1q_s16(y0->bsums+8)); + const int16x8_t y1_sums = vpaddq_s16(vld1q_s16(y1->bsums), vld1q_s16(y1->bsums+8)); + bias[0] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x0_mins)), + vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x0_mins)))); + bias[1] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x0_mins)), + vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x0_mins)))); + bias[2] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x1_mins)), + vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x1_mins)))); + bias[3] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x1_mins)), + vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x1_mins)))); + const float32x4_t dmins = { + GGML_FP16_TO_FP32(x0->dmin) 
* y0->d, + GGML_FP16_TO_FP32(x0->dmin) * y1->d, + GGML_FP16_TO_FP32(x1->dmin) * y0->d, + GGML_FP16_TO_FP32(x1->dmin) * y1->d, + }; + vfsum = vmlsq_f32(vfsum, vcvtq_f32_s32(vld1q_s32(bias)), dmins); + + const float32x4_t superblock_scale = { + GGML_FP16_TO_FP32(x0->d) * y0->d, + GGML_FP16_TO_FP32(x0->d) * y1->d, + GGML_FP16_TO_FP32(x1->d) * y0->d, + GGML_FP16_TO_FP32(x1->d) * y1->d, + }; + vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale); + } + } + + // vfsum = ABCD -> ACBD + // AC -> s, BD -> (s+bs) + vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2)); + vst1_f32(s, vget_low_f32 (vfsum)); + vst1_f32(s + bs, vget_high_f32(vfsum)); + + return; + } +#endif + #ifdef __ARM_FEATURE_SVE float sumf = 0; for (int i = 0; i < nb; ++i) { @@ -7180,14 +7548,130 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); -#elif defined __riscv_v_intrinsic +#elif defined __riscv_xtheadvector const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; - const int vector_length = __riscv_vlenb() * 8; float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int tmp, tmp2, sumi; + __asm__ __volatile__( + "li %[t1], 12\n\t" + "th.vsetvli zero, %[t1], e8, m1\n\t" + "th.vlb.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]} + "li %[t1], 4\n\t" + "th.vsetvli zero, %[t1], e32, m1\n\t" + "th.vslidedown.vi v2, v1, 2\n\t" + "th.vmv.v.v v3, v2\n\t" + "th.vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} + "li %[t1], 2\n\t" + "th.vsetvli zero, %[t1], e32, m1\n\t" + "th.vmv.v.i v4, 4\n\t" + "th.vand.vx v8, v1, %[kmask1]\n\t" + "th.vslide1up.vx v5, v4, zero\n\t" // {0, 4} + "th.vsrl.vi v6, v1, 6\n\t" + "th.vsrl.vv v7, v2, v5\n\t" + "th.vand.vx v0, v6, %[kmask3]\n\t" + "th.vand.vx v2, v7, %[kmask2]\n\t" + "th.vsll.vi v6, v0, 4\n\t" + "li %[t2], 8\n\t" + "addi %[t1], %[utmp], 4\n\t" + "th.vor.vv v1, v6, v2\n\t" + "th.vssw.v v8, (%[utmp]), %[t2]\n\t" + "th.vssw.v v1, (%[t1]), %[t2]\n\t" + "th.vsetvli zero, zero, e32, m2\n\t" // vl == 8 + "th.vlw.v v2, (%[bsums])\n\t" + "th.vsetvli zero, %[t2], e16, m1\n\t" + "th.vnsrl.vi v0, v2, 0\n\t" + "th.vnsrl.vi v1, v2, 16\n\t" + "th.vadd.vv v2, v0, v1\n\t" + "th.vlbu.v v4, (%[mins])\n\t" + "th.vwmul.vv v6, v4, v2\n\t" + "th.vmv.v.x v0, zero\n\t" + "th.vsetvli zero, %[t2], e32, m2\n\t" + "th.vredsum.vs v0, v6, v0\n\t" + "th.vmv.x.s %[sumi], v0" + : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi) + : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) + , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1) + , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf -= dmin * sumi; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + sumi = 0; + const uint8_t * scale = scales; + + for (int j = 0; j < QK_K/128; ++j) { + int vl128 = 128, vl64 = 64, vl32 = 32; + __asm__ __volatile__( + "th.vsetvli zero, %[vl128], e8, m8\n\t" + "th.vlb.v v8, (%[q8])\n\t" + "th.vsetvli zero, %[vl64], e8, m4\n\t" + "th.vlb.v v0, (%[q4])\n\t" + "th.vsrl.vi v4, v0, 4\n\t" + "th.vand.vi v0, v0, 0xF\n\t" + "th.vsetvli zero, %[vl32], e8, m2\n\t" + "th.vwmul.vv v28, v6, v14\n\t" + "th.vwmul.vv v20, v4, 
v10\n\t" + "th.vwmul.vv v24, v2, v12\n\t" + "th.vwmul.vv v16, v0, v8\n\t" + "li %[tmp], 4\n\t" + "th.vsetvli zero, %[tmp], e32, m1\n\t" + "th.vlbu.v v1, (%[scale])\n\t" + "th.vmv.v.x v0, zero\n\t" + "th.vsetvli zero, %[vl32], e16, m4\n\t" + "th.vwredsum.vs v6, v24, v0\n\t" + "th.vwredsum.vs v7, v28, v0\n\t" + "th.vwredsum.vs v4, v16, v0\n\t" + "th.vwredsum.vs v5, v20, v0\n\t" + "th.vsetvli zero, %[tmp], e32, m1\n\t" + "th.vslideup.vi v6, v7, 1\n\t" + "th.vslideup.vi v4, v5, 1\n\t" + "th.vslideup.vi v4, v6, 2\n\t" + "th.vmul.vv v8, v4, v1\n\t" + "th.vredsum.vs v0, v8, v0\n\t" + "th.vmv.x.s %[tmp], v0\n\t" + "add %[sumi], %[sumi], %[tmp]" + : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi) + : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32) + , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + + q4 += 64; q8 += 128; scale += 4; + } + + sumf += d * sumi; + + } + + *s = sumf; + +#elif defined __riscv_v + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + const int vector_length = __riscv_vlenb() * 8; + switch (vector_length) { case 256: for (int i = 0; i < nb; ++i) { @@ -8074,7 +8558,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi *s = sumf; -#elif defined __riscv_v_intrinsic +#elif defined __riscv_v const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; @@ -9232,11 +9716,92 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; -#elif defined __riscv_v_intrinsic +#elif defined __riscv_xtheadvector - const int vector_length = __riscv_vlenb() * 8; float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int sum_t = 0; + int t0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "th.vsetvli zero, %[vl32], e8, m2\n\t" // vl == 32 + "th.vlb.v v4, (%[qh])\n\t" + "th.vsll.vi v0, v4, 4\n\t" + "th.vsll.vi v2, v4, 2\n\t" + "th.vsrl.vi v6, v4, 2\n\t" + "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64 + "th.vlb.v v8, (%[q6])\n\t" + "th.vsrl.vi v12, v8, 4\n\t" + "th.vand.vi v8, v8, 0xF\n\t" + "th.vsetvli zero, %[vl128], e8, m8\n\t" // vl == 128 + "th.vand.vx v0, v0, %[mask]\n\t" + "th.vor.vv v8, v8, v0\n\t" + "th.vlb.v v0, (%[q8])\n\t" + "th.vsub.vx v8, v8, %[vl32]\n\t" + "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64 + "th.vwmul.vv v16, v0, v8\n\t" + "th.vwmul.vv v24, v4, v12\n\t" + "li %[t0], 16\n\t" + "th.vsetvli zero, %[t0], e16, m2\n\t" // vl == 16 + "th.vmv.v.x v0, zero\n\t" + "th.vwredsum.vs v10, v16, v0\n\t" + "th.vwredsum.vs v9, v18, v0\n\t" + "th.vwredsum.vs v8, v20, v0\n\t" + "th.vwredsum.vs v7, v22, v0\n\t" + "th.vwredsum.vs v11, v24, v0\n\t" + "th.vwredsum.vs v12, v26, v0\n\t" + "th.vwredsum.vs v13, v28, v0\n\t" + "th.vwredsum.vs v14, v30, v0\n\t" + "li %[t0], 4\n\t" + "th.vsetvli zero, %[t0], e32, m1\n\t" // vl == 4 + "th.vslideup.vi v10, v9, 1\n\t" + "th.vslideup.vi v8, v7, 1\n\t" + "th.vslideup.vi v11, v12, 1\n\t" + "th.vslideup.vi v13, v14, 1\n\t" + "th.vslideup.vi v10, v8, 2\n\t" + "th.vslideup.vi v11, v13, 
2\n\t" + "li %[t0], 8\n\t" + "th.vsetvli zero, %[t0], e32, m2\n\t" // vl == 8 + "th.vlb.v v4, (%[scale])\n\t" + "th.vmul.vv v2, v4, v10\n\t" + "th.vredsum.vs v0, v2, v0\n\t" + "th.vmv.x.s %[t0], v0\n\t" + "add %[sumi], %[sumi], %[t0]" + : [sumi] "+&r" (sum_t), [t0] "=&r" (t0) + : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + , [mask] "r" (0x30) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q6 += 64; qh += 32; q8 += 128; scale += 8; + } + + sumf += d * sum_t; + + } + + *s = sumf; + +#elif defined __riscv_v + + float sumf = 0; + const int vector_length = __riscv_vlenb() * 8; + switch (vector_length) { case 256: for (int i = 0; i < nb; ++i) { diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 46f75ad97..c7426df2b 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -270,7 +270,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .from_float = quantize_row_q4_K, .vec_dot = ggml_vec_dot_q4_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, +#if defined (__ARM_FEATURE_MATMUL_INT8) + .nrows = 2, +#else .nrows = 1, +#endif }, [GGML_TYPE_Q5_K] = { .from_float = quantize_row_q5_K, @@ -2414,12 +2418,32 @@ static bool ggml_thread_apply_priority(int32_t prio) { // This is up to the applications. DWORD p = THREAD_PRIORITY_NORMAL; switch (prio) { + case GGML_SCHED_PRIO_LOW: p = THREAD_PRIORITY_BELOW_NORMAL; break; case GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break; case GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break; case GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break; case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break; } + if (prio != GGML_SCHED_PRIO_LOW) { + // Tell Windows that this thread should not be throttled (needs its own CPU core). 
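
For context, the Win32 facility used just below, shown as a standalone sketch (the helper name is ours; the structure and call are the documented Windows API, guarded the same way as in the patch):

#include <windows.h>

/* Ask Windows not to apply power throttling (EcoQoS) to the calling thread, so the
   scheduler keeps it on a full-speed core instead of parking it with the others. */
static int ggml_win32_disable_thread_throttling(void) {
#if _WIN32_WINNT >= 0x0602
    THREAD_POWER_THROTTLING_STATE t;
    ZeroMemory(&t, sizeof(t));
    t.Version     = THREAD_POWER_THROTTLING_CURRENT_VERSION;
    t.ControlMask = THREAD_POWER_THROTTLING_EXECUTION_SPEED; /* take manual control ... */
    t.StateMask   = 0;                                       /* ... and turn throttling off */
    return SetThreadInformation(GetCurrentThread(), ThreadPowerThrottling, &t, sizeof(t)) != 0;
#else
    return 1; /* SDK does not expose the API: nothing to do */
#endif
}
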
+ // Newer Windows 11 versions aggresively park (offline) CPU cores and often place + // all our threads onto the first 4 cores which results in terrible performance with + // n_threads > 4 + #if _WIN32_WINNT >= 0x0602 + THREAD_POWER_THROTTLING_STATE t; + ZeroMemory(&t, sizeof(t)); + t.Version = THREAD_POWER_THROTTLING_CURRENT_VERSION; + t.ControlMask = THREAD_POWER_THROTTLING_EXECUTION_SPEED; + t.StateMask = 0; + + if (!SetThreadInformation(GetCurrentThread(), ThreadPowerThrottling, &t, sizeof(t))) { + GGML_LOG_DEBUG("failed to disable thread power throttling %d : (%d)\n", prio, (int) GetLastError()); + return false; + } + #endif + } + if (prio == GGML_SCHED_PRIO_NORMAL) { // Keep inherited policy/priority return true; @@ -2447,6 +2471,8 @@ static bool ggml_thread_apply_priority(int32_t prio) { struct sched_param p; int32_t policy = SCHED_OTHER; switch (prio) { + // TODO: there seems to be no way to set lower prio on Apple platforms + case GGML_SCHED_PRIO_LOW: policy = SCHED_OTHER; p.sched_priority = 0; break; case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; @@ -2503,6 +2529,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { struct sched_param p; int32_t policy = SCHED_OTHER; switch (prio) { + case GGML_SCHED_PRIO_LOW: policy = SCHED_BATCH; p.sched_priority = 0; break; case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; @@ -3484,6 +3511,19 @@ void ggml_cpu_init(void) { const uint64_t t_end = ggml_time_us(); UNUSED(t_end); GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0); + +#ifdef GGML_USE_OPENMP + //if (!getenv("OMP_WAIT_POLICY")) { + // // set the wait policy to active, so that OpenMP threads don't sleep + // putenv("OMP_WAIT_POLICY=active"); + //} + + if (!getenv("KMP_BLOCKTIME")) { + // set the time to wait before sleeping a thread + // this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases + putenv("KMP_BLOCKTIME=200"); // 200ms + } +#endif } #if defined(__ARM_ARCH) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 26501b711..d8de7531b 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7633,39 +7633,83 @@ static void ggml_compute_forward_ssm_scan_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // 
{d_state, n_t, n_s} - float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + #ifdef __ARM_FEATURE_SVE + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} - // use the output as the source for the next token-wise iterations - if (i2 > 0) { s0 = s; } + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + float dt_soft_plus = dt[i1] <= 20.0f ? 
log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt); + svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus); + svfloat32_t r1_vector = GGML_F32_VEC_ZERO; + + for (int64_t k = 0; k < nc; k += svcntw()) { + svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]); + svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]); + svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]); + svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]); + + svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA); + t1 = exp_ps_sve(svptrue_b32(), t1); + svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB); + + vs0 = GGML_F32_VEC_FMA(vs0, t1, t2); + r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector); + + GGML_F32_VEC_STORE(&s[i1*nc + k], vs0); + } + y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector); } - y[i1] = sumf; } } - } + #else + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? 
log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + // state = prev_state * dA + dB * x + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[i0]; + s[i] = state; + } + y[i1] = sumf; + } + } + } + #endif } void ggml_compute_forward_ssm_scan( @@ -8070,6 +8114,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32( #define GGML_F32X_MUL GGML_F32x16_MUL #define GGML_F32X_FMA GGML_F32x16_FMA #define WKV_VECTOR_SIZE 16 + #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + #define GGML_F32X GGML_F32xt + #define GGML_F32X_SET1 GGML_F32xt_SET1 + #define GGML_F32X_LOAD GGML_F32xt_LOAD + #define GGML_F32X_STORE GGML_F32xt_STORE + #define GGML_F32X_MUL GGML_F32xt_MUL + #define GGML_F32X_FMA GGML_F32xt_FMA + #define WKV_VECTOR_SIZE 8 #elif defined(__ARM_NEON) && defined(__aarch64__) #define GGML_F32X GGML_F32x4 #define GGML_F32X_SET1 GGML_F32x4_SET1 @@ -8080,8 +8132,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32( #define WKV_VECTOR_SIZE 4 #endif + int wkv_vector_size; #ifdef WKV_VECTOR_SIZE - const int64_t vec_count = head_size / WKV_VECTOR_SIZE; + #if defined(__ARM_FEATURE_SVE) + wkv_vector_size = svcntw(); + #else + wkv_vector_size = WKV_VECTOR_SIZE; + #endif + const int64_t vec_count = head_size / wkv_vector_size; for (int64_t t = 0; t < T; t++) { size_t t_offset = t * t_stride; @@ -8111,7 +8169,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val); for (int64_t j = 0; j < vec_count; j++) { - size_t base_j = j * WKV_VECTOR_SIZE; + size_t base_j = j * wkv_vector_size; size_t t_h_j_offset = t_h_offset + base_j; size_t h_2d_i_j_offset = h_2d_i_offset + base_j; @@ -8136,7 +8194,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( } // Handle remaining elements, this will not be used. 
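
Because the SVE vector width is only known at run time, the wkv6 hunk above turns the compile-time WKV_VECTOR_SIZE constant into a per-run value; a small sketch of that idea (the helper name is illustrative only):

#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
#endif

/* 32-bit lanes processed per vector op: 4 on SVE128, 8 on SVE256, 16 on SVE512.
   vec_count = head_size / lanes; anything past vec_count * lanes is left to the
   scalar remainder loop below. */
static inline int wkv_lane_count(void) {
#if defined(__ARM_FEATURE_SVE)
    return (int) svcntw();
#else
    return 4; /* fixed NEON-style width, as in the non-SVE builds */
#endif
}
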
- for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) { + for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) { size_t t_h_j_offset = t_h_offset + j; size_t h_2d_i_j_offset = h_2d_i_offset + j; float v_val = v[t_h_j_offset]; @@ -8272,6 +8330,14 @@ static void ggml_compute_forward_gla_f32( #define GGML_F32X_MUL GGML_F32x16_MUL #define GGML_F32X_FMA GGML_F32x16_FMA #define GLA_VECTOR_SIZE 16 + #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + #define GGML_F32X GGML_F32xt + #define GGML_F32X_SET1 GGML_F32xt_SET1 + #define GGML_F32X_LOAD GGML_F32xt_LOAD + #define GGML_F32X_STORE GGML_F32xt_STORE + #define GGML_F32X_MUL GGML_F32xt_MUL + #define GGML_F32X_FMA GGML_F32xt_FMA + #define GLA_VECTOR_SIZE 8 #elif defined(__ARM_NEON) && defined(__aarch64__) #define GGML_F32X GGML_F32x4 #define GGML_F32X_SET1 GGML_F32x4_SET1 @@ -8282,8 +8348,14 @@ static void ggml_compute_forward_gla_f32( #define GLA_VECTOR_SIZE 4 #endif + int gla_vector_size; #ifdef GLA_VECTOR_SIZE - const int64_t vec_count = head_size / GLA_VECTOR_SIZE; + #if defined(__ARM_FEATURE_SVE) + gla_vector_size = svcntw(); + #else + gla_vector_size = GLA_VECTOR_SIZE; + #endif + const int64_t vec_count = head_size / gla_vector_size; for (int64_t t = 0; t < T; t++) { size_t t_offset = t * t_stride; @@ -8310,7 +8382,7 @@ static void ggml_compute_forward_gla_f32( GGML_F32X g_vec = GGML_F32X_SET1(g_val); for (int64_t j = 0; j < vec_count; j++) { - size_t base_j = j * GLA_VECTOR_SIZE; + size_t base_j = j * gla_vector_size; size_t t_h_j_offset = t_h_offset + base_j; size_t h_2d_i_j_offset = h_2d_i_offset + base_j; @@ -8334,7 +8406,7 @@ static void ggml_compute_forward_gla_f32( } // Handle remaining elements, this will not be used. - for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) { + for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) { size_t t_h_j_offset = t_h_offset + j; size_t h_2d_i_j_offset = h_2d_i_offset + j; float v_val = v[t_h_j_offset]; @@ -8443,83 +8515,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32( int64_t h_stride_2d = head_size * head_size; #if defined(GGML_SIMD) - for (int64_t t = 0; t < T; t++) { - int64_t t_offset = t * t_stride; - int64_t state_offset = head_size * C * (t / (T / n_seqs)); - float * state_cur = state + state_offset; - float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset; + #if defined(__ARM_FEATURE_SVE) + // scalar Route to scalar implementation //TODO: Write SVE code + for (int64_t t = 0; t < T; t++) { + int64_t t_offset = t * t_stride; + int64_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? 
state_cur : (float*)dst->src[6]->data + state_offset; - for (int64_t h = h_start; h < h_end; h++) { - int64_t h_offset = h * h_stride; - int64_t t_h_offset = t_offset + h_offset; - int64_t h_2d_offset = h * h_stride_2d; + for (int64_t h = h_start; h < h_end; h++) { + int64_t h_offset = h * h_stride; + int64_t t_h_offset = t_offset + h_offset; + int64_t h_2d_offset = h * h_stride_2d; - for (int64_t ii = 0; ii < head_size; ii++) { - int64_t t_h_i_offset = t_h_offset + ii; - int64_t h_2d_i_offset = h_2d_offset + ii * h_stride; + for (int64_t i = 0; i < head_size; i++) { + int64_t t_h_i_offset = t_h_offset + i; + int64_t h_2d_i_offset = h_2d_offset + i * h_stride; - GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]); + float v_val = v[t_h_i_offset]; - float sa = 0; - { - GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) { - for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { - ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]); - ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]); - sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]); - } + float sa = 0, result = 0; + for (int64_t j = 0; j < head_size; j++) { + sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j]; } - GGML_F32_VEC_REDUCE(sa, sum); - } - GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa); + for (int64_t j = 0; j < head_size; j++) { + int64_t t_h_j_offset = t_h_offset + j; + int64_t h_2d_i_j_offset = h_2d_i_offset + j; - int64_t j = 0; - GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - for (; j < head_size; j += GGML_F32_STEP) { - for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { - int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR; - int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR; - - GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]); - GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]); - GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]); - GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]); - - k_vec = GGML_F32_VEC_MUL(v_vec, k_vec); - - GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]); - // kv + s * decay + sa * b - state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec); - state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec); - GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec); - - result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec); + float r_val = r[t_h_j_offset]; + float w_val = w[t_h_j_offset]; + float k_val = k[t_h_j_offset]; + float b_val = b[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val; + result += state_cur[h_2d_i_j_offset] * r_val; } - } - GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec); - - // There shouldn't be left-overs though. 
- for (; j < head_size; j++) { - int64_t t_h_j_offset = t_h_offset + j; - int64_t h_2d_i_j_offset = h_2d_i_offset + j; - - float r_val = r[t_h_j_offset]; - float w_val = w[t_h_j_offset]; - float k_val = k[t_h_j_offset]; - float b_val = b[t_h_j_offset]; - float kv_val = v[t_h_i_offset] * k_val; - - float prev_state_val = state_prev[h_2d_i_j_offset]; - state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val; - dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val; + dst_data[t_h_i_offset] = result; } } } - } + #else + for (int64_t t = 0; t < T; t++) { + int64_t t_offset = t * t_stride; + int64_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset; + + for (int64_t h = h_start; h < h_end; h++) { + int64_t h_offset = h * h_stride; + int64_t t_h_offset = t_offset + h_offset; + int64_t h_2d_offset = h * h_stride_2d; + + for (int64_t ii = 0; ii < head_size; ii++) { + int64_t t_h_i_offset = t_h_offset + ii; + int64_t h_2d_i_offset = h_2d_offset + ii * h_stride; + + GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]); + + float sa = 0; + { + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) { + for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { + ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]); + ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]); + sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]); + } + } + GGML_F32_VEC_REDUCE(sa, sum); + } + + GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa); + + int64_t j = 0; + GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + for (; j < head_size; j += GGML_F32_STEP) { + for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { + int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR; + int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR; + + GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]); + GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]); + GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]); + GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]); + + k_vec = GGML_F32_VEC_MUL(v_vec, k_vec); + + GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]); + // kv + s * decay + sa * b + state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec); + state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec); + GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec); + + result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec); + } + } + GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec); + + // There shouldn't be left-overs though. 
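
The SVE branch above falls back to a scalar loop for now (see the TODO), while the generic SIMD path retained here vectorizes the same recurrence; for reference, one output row of that recurrence in plain scalar form (a sketch, not the backend code):

/* wkv7 row update: sa = dot(a, prev_state_row); state = prev_state*w + v*k + sa*b;
   the row's output is dot(state, r). All pointers address one head_size-long row. */
static float wkv7_row(float * state_cur, const float * state_prev,
                      const float * r, const float * w, const float * k,
                      const float * a, const float * b, float v, int head_size) {
    float sa = 0.0f;
    for (int j = 0; j < head_size; ++j) {
        sa += a[j] * state_prev[j];
    }
    float out = 0.0f;
    for (int j = 0; j < head_size; ++j) {
        state_cur[j] = state_prev[j] * w[j] + v * k[j] + sa * b[j];
        out += state_cur[j] * r[j];
    }
    return out;
}
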
+ for (; j < head_size; j++) { + int64_t t_h_j_offset = t_h_offset + j; + int64_t h_2d_i_j_offset = h_2d_i_offset + j; + + float r_val = r[t_h_j_offset]; + float w_val = w[t_h_j_offset]; + float k_val = k[t_h_j_offset]; + float b_val = b[t_h_j_offset]; + float kv_val = v[t_h_i_offset] * k_val; + + float prev_state_val = state_prev[h_2d_i_j_offset]; + state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val; + dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val; + } + } + } + } + #endif #else for (int64_t t = 0; t < T; t++) { int64_t t_offset = t * t_stride; diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 45c31cf1f..2e3669c01 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -17,7 +17,123 @@ // number of elements to fit in a single register // -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_FMA) + +#define GGML_SIMD + +// F32 SVE +#define GGML_F32_EPR 8 +#define DEFAULT_PG svptrue_b32() + +#define GGML_F32xt svfloat32_t +#define GGML_F32xt_ZERO svdup_n_f32(0.0f) +#define GGML_F32xt_SET1(x) svdup_n_f32(x) +#define GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a) +#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b) +#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c) +#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b) +#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b) +#define GGML_F32xt_MUL(...) GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a) +#define GGML_F32xt_REDUCE_ONE(...) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \ +{ \ + sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \ + sum3 = svadd_f32_m(DEFAULT_PG, sum3, sum4); \ + sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum6); \ + sum7 = svadd_f32_m(DEFAULT_PG, sum7, sum8); \ + sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum3); \ + sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum7); \ + sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \ + (res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \ +} +#define GGML_F32xt_REDUCE(...) 
GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__) + +#define GGML_F32_VEC GGML_F32xt +#define GGML_F32_VEC_ZERO GGML_F32xt_ZERO +#define GGML_F32_VEC_SET1 GGML_F32xt_SET1 +#define GGML_F32_VEC_LOAD GGML_F32xt_LOAD +#define GGML_F32_VEC_STORE GGML_F32xt_STORE +#define GGML_F32_VEC_FMA GGML_F32xt_FMA +#define GGML_F32_VEC_ADD GGML_F32xt_ADD +#define GGML_F32_VEC_MUL GGML_F32xt_MUL +#define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE + +// F16 NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + #define GGML_F16_STEP 32 + #define GGML_F16_EPR 8 + + #define GGML_F16x8 float16x8_t + #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) + #define GGML_F16x8_SET1(x) vdupq_n_f16(x) + #define GGML_F16x8_LOAD(x) vld1q_f16((const __fp16 *)(x)) + #define GGML_F16x8_STORE vst1q_f16 + #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) + #define GGML_F16x8_ADD vaddq_f16 + #define GGML_F16x8_MUL vmulq_f16 + #define GGML_F16x8_REDUCE(res, x) \ + do { \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \ + (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ + } while (0) + + #define GGML_F16_VEC GGML_F16x8 + #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO + #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), (r)[i]) + #define GGML_F16_VEC_FMA GGML_F16x8_FMA + #define GGML_F16_VEC_ADD GGML_F16x8_ADD + #define GGML_F16_VEC_MUL GGML_F16x8_MUL + #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE +#else + // if FP16 vector arithmetic is not supported, we use FP32 instead + // and take advantage of the vcvt_ functions to convert to/from FP16 + + #define GGML_F16_STEP 16 + #define GGML_F16_EPR 4 + + #define GGML_F32Cx4 float32x4_t + #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) + #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) + #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const __fp16 *)(x))) + #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) + #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) + #define GGML_F32Cx4_ADD vaddq_f32 + #define GGML_F32Cx4_MUL vmulq_f32 + #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + + #define GGML_F16_VEC GGML_F32Cx4 + #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO + #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i]) + #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA + #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD + #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL + #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE +#endif + +#elif defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) #define GGML_SIMD diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 02d406182..f7614568e 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -17,29 +17,98 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G #if defined(GGML_SIMD) float sumf = 0.0f; - const int np = (n & ~(GGML_F32_STEP - 1)); - GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + #if defined(__ARM_FEATURE_SVE) + const int sve_register_length = 
ggml_cpu_get_sve_cnt() * 8; + const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16 + const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; + const int np = (n & ~(ggml_f32_step - 1)); + svfloat32_t sum1 = svdup_n_f32(0.0f); + svfloat32_t sum2 = svdup_n_f32(0.0f); + svfloat32_t sum3 = svdup_n_f32(0.0f); + svfloat32_t sum4 = svdup_n_f32(0.0f); + svfloat32_t sum5 = svdup_n_f32(0.0f); + svfloat32_t sum6 = svdup_n_f32(0.0f); + svfloat32_t sum7 = svdup_n_f32(0.0f); + svfloat32_t sum8 = svdup_n_f32(0.0f); + svfloat32_t ax1,ax2,ax3,ax4,ax5,ax6,ax7,ax8; + svfloat32_t ay1,ay2,ay3,ay4,ay5,ay6,ay7,ay8; + for (int i = 0; i < np; i += ggml_f32_step) { + ax1 = GGML_F32_VEC_LOAD(x + i); + ay1 = GGML_F32_VEC_LOAD(y + i); + sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); + ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); + sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2); - sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); + ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); + sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3); + + ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); + ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); + sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4); + + ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); + ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); + sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5); + + ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); + ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); + sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6); + + ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); + ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); + sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7); + + ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); + ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); + sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8); } - } + // leftovers + // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop + const int np2 = (n & ~(ggml_f32_epr - 1)); + for (int i = np; i < np2; i += ggml_f32_epr) { + ax1 = GGML_F32_VEC_LOAD(x + i); + ay1 = GGML_F32_VEC_LOAD(y + i); + sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); + } + // maximum number of leftover elements will be less that ggml_f32_epr. 
Apply predicated svmad on available elements only + if (np2 < n) { + svbool_t pg = svwhilelt_b32(np2, n); + ax1 = svld1_f32(pg, x + np2); + ay1 = svld1_f32(pg, y + np2); + sum1 = svmad_f32_m(pg, ax1, ay1, sum1); + } + // reduce sum1,sum2 to sum1 + GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8); + #else + const int np = (n & ~(GGML_F32_STEP - 1)); - // reduce sum0..sum3 to sum0 - GGML_F32_VEC_REDUCE(sumf, sum); + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - // leftovers - for (int i = np; i < n; ++i) { - sumf += x[i]*y[i]; - } + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += x[i]*y[i]; + } + #endif #else // scalar ggml_float sumf = 0.0; diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index c77349ebe..09dbade21 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -5,6 +5,7 @@ #include "ggml-impl.h" #include "simd-mappings.h" #include "ggml.h" +#include "ggml-cpu.h" #if defined(GGML_USE_ACCELERATE) #include @@ -148,27 +149,108 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const float * GGML_RESTRICT x, const float v) { #if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + const int sve_register_length = ggml_cpu_get_sve_cnt() * 8; + const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16 + const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; + const int np = (n & ~(ggml_f32_step - 1)); + svfloat32_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; + svfloat32_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + for (int i = 0; i < np; i += ggml_f32_step) { - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + ax1 = GGML_F32_VEC_LOAD(x + i); + ay1 = GGML_F32_VEC_LOAD(y + i); + ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + GGML_F32_VEC_STORE(y + i, ay1); + + ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); + ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); + ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2); + + GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2); + + ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); + ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); + ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3); + + GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3); + + ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); + ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); + ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4); + + GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4); + + ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); + ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); + ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5); + + GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5); + + ax6 = 
GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); + ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); + ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6); + + GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6); + + ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); + ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); + ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7); + + GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7); + + ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); + ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); + ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8); + + GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8); } - } + // leftovers + // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop + const int np2 = (n & ~(ggml_f32_epr - 1)); + for (int i = np; i < np2; i += ggml_f32_epr) { + ax1 = GGML_F32_VEC_LOAD(x + i); + ay1 = GGML_F32_VEC_LOAD(y + i); + ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); - // leftovers - for (int i = np; i < n; ++i) { - y[i] += x[i]*v; - } + GGML_F32_VEC_STORE(y + i, ay1); + } + // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only + if (np2 < n) { + svbool_t pg =svwhilelt_b32(np2, n); + ax1 = svld1_f32(pg, x + np2); + ay1 = svld1_f32(pg, y + np2); + ay1 = svmad_f32_m(pg, ax1, vx, ay1); + + svst1_f32(pg, y + np2, ay1); + } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += x[i]*v; + } + #endif #else // scalar for (int i = 0; i < n; ++i) { @@ -220,36 +302,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int } #if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL]; - - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - vx[k] = GGML_F32_VEC_SET1(v[k][0]); - } - - GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); + #if defined(__ARM_FEATURE_SVE) + // scalar Route to scalar implementation //TODO: Write SVE code + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = 0; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; } - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); } - } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); - // leftovers - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - for (int i = np; i < n; ++i) { - y[i] += x[k][i]*v[k][0]; + GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL]; + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + vx[k] = GGML_F32_VEC_SET1(v[k][0]); } - } + + GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + ax[k][j] = 
GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); + } + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = np; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } + #endif #else // scalar for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { @@ -265,25 +356,53 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #if defined(GGML_USE_ACCELERATE) vDSP_vsmul(y, 1, &v, y, 1, n); #elif defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) + const int sve_register_length = ggml_cpu_get_sve_cnt() * 8; + const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16 + const int ggml_f32_step = 2 * ggml_f32_epr; - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + const int np = (n & ~(ggml_f32_step - 1)); + svfloat32_t ay1; + svfloat32_t ay2; + for (int i = 0; i < np; i += ggml_f32_step) { + ay1 = GGML_F32_VEC_LOAD(y + i); + ay1 = GGML_F32_VEC_MUL(ay1, vx); + GGML_F32_VEC_STORE(y + i, ay1); - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_MUL(ay[j], vx); - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); + ay2 = GGML_F32_VEC_MUL(ay2, vx); + GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2); } - } + // leftovers + // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only + if (np < n) { + svbool_t pg = svwhilelt_b32(np, n); + ay1 = svld1_f32(pg, y + np); + ay1 = svmul_f32_m(pg, ay1, vx); + svst1_f32(pg, y + np, ay1); + } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); - // leftovers - for (int i = np; i < n; ++i) { - y[i] *= v; - } + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_MUL(ay[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] *= v; + } + #endif #else // scalar for (int i = 0; i < n; ++i) { @@ -528,6 +647,42 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) { #error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461" #endif +/* Below function was borrowed from the GitHub repository: +https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */ +#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) { + // Constants + const svfloat32_t log2_e = svdup_n_f32(1.4426950409f); + const svfloat32_t ln2 = svdup_n_f32(0.6931473921f); + const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f); + const svuint32_t not_mask17 = svdup_n_u32(~((1u << 17) - 1)); + const svfloat32_t one = svdup_n_f32(1.0f); + const svfloat32_t inactive1 = svdup_n_f32(0.0f); + const svint32_t inactive2 = svdup_n_s32(0); + + // Algorithm starts here + svfloat32_t t0 = svmul_f32_m(pg, src, log2_e); // y = x * log2(e) + svfloat32_t t1 = svrintm_f32_m(inactive1, pg, t0); // rount to int (float) + svint32_t t2 = svcvt_s32_f32_m(inactive2, pg, t1); // n + + t1 
= svsub_f32_m(pg, t0, t1); // a = y - floor(y) + t1 = svadd_f32_m(pg, t1, one); // b = a + 1 + + svuint32_t t3 = svlsr_n_u32_m(pg, svreinterpret_u32_f32(t1), 17); // v = b >> 17 (u32) + svfloat32_t t4 = svexpa_f32(t3); // c = fexpa(v) + t4 = svscale_f32_m(pg, t4, t2); // fexpa(v) * 2^(n) + + // and_(t2.d, t1.d, not_mask17.d) + svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_m(pg, svreinterpret_u32_f32(t1), not_mask17)); + t5 = svsub_f32_m(pg, t1, t5); // z + t0 = svmla_f32_m(pg, ln2, t5, half_ln2_sq); // ln2 + half_ln2_sq * z + t0 = svmla_f32_m(pg, one, t5, t0); // 1 + (ln2 * z) + (half_ln2_sq * z * z) + t0 = svmul_f32_m(pg, t0, t4); // Final result + + return t0; + } +#endif + #if defined(__ARM_NEON) && defined(__aarch64__) // adapted from arm limited optimized routine diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 64fb4ff4c..e1ce1d4cd 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -168,7 +168,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str) -#if !defined(GGML_USE_HIP) +#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) static const char * cu_get_error_str(CUresult err) { const char * err_str; cuGetErrorString(err, &err_str); @@ -635,6 +635,7 @@ struct ggml_cuda_device_info { int nsm; // number of streaming multiprocessors size_t smpb; // max. shared memory per block size_t smpbo; // max. shared memory per block (with opt-in) + bool integrated; // Device is integrated as opposed to discrete bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory size_t total_vram; diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index a4fbd8236..cfab2b5eb 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -623,8 +623,8 @@ static __global__ void flash_attn_combine_results( __builtin_assume(tid < D); extern __shared__ float2 meta[]; - if (tid < 2*parallel_blocks) { - ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid]; + for (int i = tid; i < 2*parallel_blocks; i += D) { + ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i]; } __syncthreads(); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 7120053b6..925f39e89 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1246,7 +1246,7 @@ static __global__ void flash_attn_ext_f16( NO_DEVICE_CODE; return; } -#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 798a59b27..35e649cb3 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -212,6 +212,7 @@ static __global__ void flash_attn_vec_ext_f16( } } if (__all_sync(0xFFFFFFFF, skip)) { + __syncthreads(); continue; } #endif // GGML_USE_HIP diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 49c592ea5..953967917 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -217,6 +217,7 @@ static __global__ void flash_attn_vec_ext_f32( } } if (__all_sync(0xFFFFFFFF, skip)) { + __syncthreads(); continue; } #endif // GGML_USE_HIP diff 
--git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 02dc8c12d..2a6f7f108 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -243,10 +243,10 @@ static ggml_cuda_device_info ggml_cuda_init() { info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; - - info.devices[id].nsm = prop.multiProcessorCount; - info.devices[id].smpb = prop.sharedMemPerBlock; - info.devices[id].warp_size = prop.warpSize; + info.devices[id].integrated = prop.integrated; + info.devices[id].nsm = prop.multiProcessorCount; + info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].warp_size = prop.warpSize; #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpbo = prop.sharedMemPerBlock; @@ -1065,6 +1065,10 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_ GGML_UNUSED(buft); } +static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name; +} + static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { CUDA_CHECK(cudaFreeHost(buffer->context)); } @@ -2192,6 +2196,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_SILU: ggml_cuda_op_silu(ctx, dst); break; + case GGML_UNARY_OP_GELU_ERF: + ggml_cuda_op_gelu_erf(ctx, dst); + break; case GGML_UNARY_OP_GELU_QUICK: ggml_cuda_op_gelu_quick(ctx, dst); break; @@ -2638,6 +2645,8 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { + // flag used to determine whether it is an integrated_gpu + const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; while (!graph_evaluated_or_captured) { // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. 
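// (editor's note — illustrative sketch, not part of the patch) The `integrated`
// flag read above marks devices that share physical memory with the host
// (cudaDeviceProp::integrated). On such devices, tensors kept in CUDA pinned-host
// buffers can be consumed directly, so the hunks below relax both the debug
// assertion in evaluate_and_capture_cuda_graph and
// ggml_backend_cuda_device_supports_buft. The acceptance rule reduces to roughly
// the predicate sketched here (the helper name is hypothetical; the
// ggml_backend_buft_* calls are the ones used in the diff):
static bool device_accepts_buft_sketch(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft, bool integrated) {
    // device-local and split buffers must belong to this device, as before
    if ((ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev) {
        return true;
    }
    // on integrated devices, pinned host buffers are additionally acceptable
    return integrated && ggml_backend_buft_is_cuda_host(buft);
}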
@@ -2656,7 +2665,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx if (node->src[j] != nullptr) { assert(node->src[j]->buffer); assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || - ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft)); + ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft))); } } #endif @@ -2977,6 +2986,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: @@ -2990,9 +3000,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g { struct ggml_tensor * a = op->src[0]; struct ggml_tensor * b = op->src[1]; - // for small weight matrices the active device can end up without any rows, don't use row split in those cases - // this avoids some edge cases (and the performance would not be good anyways) if (a->buffer && ggml_backend_buft_is_cuda_split(a->buffer->buft)) { + if (a->ne[2] > 1 || a->ne[3] > 1) { + return false; + } + // for small weight matrices the active device can end up without any rows, don't use row split in those cases + // this avoids some edge cases (and the performance would not be good anyways) ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) a->buffer->buft->context; int64_t row_low; int64_t row_high; @@ -3259,7 +3272,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { - return (ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev; + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; + const bool integrated = ggml_cuda_info().devices[dev_ctx->device].integrated; + return (((ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev) || (integrated && ggml_backend_buft_is_cuda_host(buft))); } static int64_t get_op_batch_size(const ggml_tensor * op) { diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index ec5773e01..2c0375fbe 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -23,6 +23,12 @@ static __device__ __forceinline__ float op_gelu(float x) { return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } +static __device__ __forceinline__ float op_gelu_erf(float x) { + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + + return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); +} + static __device__ __forceinline__ float op_gelu_quick(float x) { const float GELU_QUICK_COEF = -1.702f; @@ -134,6 +140,10 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } +void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 940a1feed..6686fc17e 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -30,6 +30,8 @@ void 
ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index a19cfb14e..6dc5ce0d9 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -32,6 +32,8 @@ extern "C" { #endif +void ggml_print_backtrace(void); + #ifndef MIN # define MIN(a, b) ((a) < (b) ? (a) : (b)) #endif @@ -386,7 +388,7 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size); return r; } -#elif defined(__riscv) && defined(GGML_RV_ZFH) +#elif defined(__riscv) && defined(__riscv_zfhmin) static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { float f; diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 352deb321..9f930c70b 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -55,14 +55,17 @@ endfunction() set(GGML_OPENCL_KERNELS add + argsort clamp cpy cvt diag_mask_inf + div gelu gemv_noshuffle_general gemv_noshuffle get_rows + group_norm im2col_f32 im2col_f16 mul_mat_Ab_Bi_8x4 @@ -83,11 +86,14 @@ set(GGML_OPENCL_KERNELS rms_norm rope scale + sigmoid silu softmax_4_f32 softmax_4_f16 softmax_f32 softmax_f16 + sub + sum_rows transpose ) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 586946048..5dbe97ab2 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #undef MIN #undef MAX @@ -74,6 +75,7 @@ struct ggml_cl_version { cl_uint minor = 0; }; + struct ggml_cl_compiler_version { ADRENO_CL_COMPILER_TYPE type; int major = -1; @@ -91,6 +93,14 @@ struct ggml_cl_compiler_version { } }; +static size_t align_to(size_t value, size_t to_alignment) { + GGML_ASSERT(to_alignment && "Invalid alignment (must be non-zero)"); + GGML_ASSERT((to_alignment & (to_alignment - 1)) == 0 && "to_alignment must be power-of-two"); + + return ((value + to_alignment - 1) / to_alignment) * to_alignment; +} + + // Parses a version string of form "XX.YY ". On an error returns ggml_cl_version with all zeroes. static ggml_cl_version parse_cl_version(std::string_view str) { size_t major_str_begin = 0; @@ -221,13 +231,25 @@ static ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *drive return { type, major, minor, patch }; } +struct ggml_backend_opencl_context; + // backend device context struct ggml_backend_opencl_device_context { cl_platform_id platform; std::string platform_name; - cl_device_id device; - std::string device_name; + cl_device_id device; + std::string device_name; + cl_device_type device_type; + std::string device_version; + + // Initialized by ggml_cl2_init(). 
+ ggml_backend_opencl_context * backend_ctx = nullptr; + + // Initialized by ggml_backend_opencl_device_get_buffer_type() + ggml_backend_buffer_type buffer_type; + + cl_context context = nullptr; }; // backend context @@ -248,6 +270,8 @@ struct ggml_backend_opencl_context { int adreno_wave_size; + cl_bool non_uniform_workgroups; + cl_context context; cl_command_queue queue; @@ -275,27 +299,37 @@ struct ggml_backend_opencl_context { cl_program program_mul_mv_f16_f32; cl_program program_mul_mv_f32_f32; cl_program program_mul; + cl_program program_div; + cl_program program_sub; cl_program program_norm; cl_program program_relu; cl_program program_rms_norm; + cl_program program_group_norm; cl_program program_rope; cl_program program_scale; cl_program program_silu; + cl_program program_sigmoid; cl_program program_softmax_f32; cl_program program_softmax_f16; cl_program program_softmax_4_f32; cl_program program_softmax_4_f16; + cl_program program_argsort_f32_i32; + cl_program program_sum_rows_f32; cl_kernel kernel_add, kernel_add_row; cl_kernel kernel_mul, kernel_mul_row; + cl_kernel kernel_div, kernel_div_row; + cl_kernel kernel_sub, kernel_sub_row; cl_kernel kernel_scale; cl_kernel kernel_silu, kernel_silu_4; cl_kernel kernel_gelu, kernel_gelu_4; cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; cl_kernel kernel_relu; + cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16; cl_kernel kernel_clamp; cl_kernel kernel_norm; cl_kernel kernel_rms_norm; + cl_kernel kernel_group_norm; cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; cl_kernel kernel_soft_max, kernel_soft_max_4; cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; @@ -315,6 +349,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_im2col_f32, kernel_im2col_f16; + cl_kernel kernel_argsort_f32_i32; + cl_kernel kernel_sum_rows_f32; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Transpose kernels @@ -344,15 +380,8 @@ struct ggml_backend_opencl_context { #endif // GGML_OPENCL_USE_ADRENO_KERNELS }; -static ggml_backend_device g_ggml_backend_opencl_device; -static ggml_backend_opencl_device_context g_ggml_ctx_dev_main { - /*.platform =*/ nullptr, - /*.platform_nane =*/ "", - /*.device =*/ nullptr, - /*.device_name =*/ "", -}; - -static int ggml_backend_opencl_n_devices = 0; +// All registered devices with a default device in the front. 
+static std::vector g_ggml_backend_opencl_devices; // Profiling #ifdef GGML_OPENCL_PROFILING @@ -969,6 +998,105 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // argsort + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "argsort.cl.h" + }; +#else + const std::string kernel_src = read_file("argsort.cl"); +#endif + backend_ctx->program_argsort_f32_i32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err)); + GGML_LOG_CONT("."); + } + + // div + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "div.cl.h" + }; +#else + const std::string kernel_src = read_file("div.cl"); +#endif + backend_ctx->program_div = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_div = clCreateKernel(backend_ctx->program_div, "kernel_div", &err), err)); + CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err)); + GGML_LOG_CONT("."); + } + + // sub + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sub.cl.h" + }; +#else + const std::string kernel_src = read_file("sub.cl"); +#endif + backend_ctx->program_sub = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sub = clCreateKernel(backend_ctx->program_sub, "kernel_sub", &err), err)); + CL_CHECK((backend_ctx->kernel_sub_row = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err)); + GGML_LOG_CONT("."); + } + + // sum_rows + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sum_rows.cl.h" + }; +#else + const std::string kernel_src = read_file("sum_rows.cl"); +#endif + backend_ctx->program_sum_rows_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // sigmoid + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sigmoid.cl.h" + }; +#else + const std::string kernel_src = read_file("sigmoid.cl"); +#endif + backend_ctx->program_sigmoid = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sigmoid_f32 = clCreateKernel(backend_ctx->program_sigmoid, "kernel_sigmoid_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sigmoid_f16 = clCreateKernel(backend_ctx->program_sigmoid, "kernel_sigmoid_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // group_norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "group_norm.cl.h" + }; +#else + const std::string kernel_src = read_file("group_norm.cl"); +#endif + backend_ctx->program_group_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err)); + GGML_LOG_CONT("."); + } + // Adreno kernels #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // transpose @@ -1107,25 +1235,19 @@ 
static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("\n"); } -static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { - static bool initialized = false; - static ggml_backend_opencl_context *backend_ctx = nullptr; +// XXX static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { +// XXX static bool initialized = false; +// XXX static ggml_backend_opencl_context *backend_ctx = nullptr; - if (initialized) { - return backend_ctx; - } +static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev); - ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *)dev->context; - GGML_ASSERT(dev_ctx); - GGML_ASSERT(dev_ctx->platform == nullptr); - GGML_ASSERT(dev_ctx->device == nullptr); - GGML_ASSERT(backend_ctx == nullptr); +namespace /* anonymous */ { +extern struct ggml_backend_device_i ggml_backend_opencl_device_i; +} - initialized = true; - backend_ctx = new ggml_backend_opencl_context(); - backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; - - cl_int err; +// Look for available and suitable devices. +static std::vector ggml_opencl_probe_devices(ggml_backend_reg * reg) { + std::vector found_devices; #ifdef GGML_OPENCL_PROFILING GGML_LOG_INFO("ggml_opencl: OpenCL profiling enabled\n"); @@ -1158,11 +1280,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { struct cl_device devices[NDEV]; unsigned n_devices = 0; struct cl_device * default_device = NULL; + unsigned default_platform_number = 0; cl_platform_id platform_ids[NPLAT]; if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) { GGML_LOG_ERROR("ggml_opencl: plaform IDs not available.\n"); - return backend_ctx; + return found_devices; } for (unsigned i = 0; i < n_platforms; i++) { @@ -1197,19 +1320,22 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { } if (default_device == NULL && p->default_device != NULL) { - default_device = p->default_device; + default_device = p->default_device; + default_platform_number = i; } } if (n_devices == 0) { GGML_LOG_ERROR("ggml_opencl: could find any OpenCL devices.\n"); - return backend_ctx; + return found_devices; } - char * user_platform_string = getenv("GGML_OPENCL_PLATFORM"); - char * user_device_string = getenv("GGML_OPENCL_DEVICE"); - int user_platform_number = -1; - int user_device_number = -1; + char * user_platform_string = getenv("GGML_OPENCL_PLATFORM"); + char * user_device_string = getenv("GGML_OPENCL_DEVICE"); + int user_platform_number = -1; + int user_device_number = -1; + cl_device * candidate_devices = nullptr; + unsigned n_candidate_devices = 0; unsigned n; if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) { @@ -1224,12 +1350,11 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_ERROR("ggml_opencl: invalid device number %d\n", user_device_number); exit(1); } - default_device = &platform->devices[user_device_number]; + default_device = &platform->devices[user_device_number]; + candidate_devices = platform->devices; + n_candidate_devices = platform->n_devices; } else { - - struct cl_device * selected_devices = devices; - unsigned n_selected_devices = n_devices; - + // Choose a platform by matching a substring. 
if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) { for (unsigned i = 0; i < n_platforms; i++) { struct cl_platform * p = &platforms[i]; @@ -1244,20 +1369,20 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { exit(1); } } - if (user_platform_number != -1) { - struct cl_platform * p = &platforms[user_platform_number]; - selected_devices = p->devices; - n_selected_devices = p->n_devices; - default_device = p->default_device; - if (n_selected_devices == 0) { - GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name); - exit(1); - } + + int platform_idx = user_platform_number != -1 ? user_platform_number : default_platform_number; + struct cl_platform * p = &platforms[platform_idx]; + candidate_devices = p->devices; + n_candidate_devices = p->n_devices; + default_device = p->default_device; + if (n_candidate_devices == 0) { + GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name); + exit(1); } if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) { - for (unsigned i = 0; i < n_selected_devices; i++) { - struct cl_device * d = &selected_devices[i]; + for (unsigned i = 0; i < n_candidate_devices; i++) { + struct cl_device * d = &candidate_devices[i]; if (strstr(d->name, user_device_string) != NULL) { user_device_number = d->number; break; @@ -1269,71 +1394,145 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { } } if (user_device_number != -1) { - selected_devices = &devices[user_device_number]; - n_selected_devices = 1; - default_device = &selected_devices[0]; + candidate_devices = &devices[user_device_number]; + n_candidate_devices = 1; + default_device = &candidate_devices[0]; } - GGML_ASSERT(n_selected_devices > 0); + GGML_ASSERT(n_candidate_devices > 0); if (default_device == NULL) { - default_device = &selected_devices[0]; + default_device = &candidate_devices[0]; } } - GGML_LOG_INFO("ggml_opencl: selecting platform: '%s'\n", default_device->platform->name); - GGML_LOG_INFO("ggml_opencl: selecting device: '%s (%s)'\n", default_device->name, default_device->version); - if (default_device->type != CL_DEVICE_TYPE_GPU) { - GGML_LOG_WARN("ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name); + GGML_ASSERT(n_candidate_devices != 0 && candidate_devices); + + // Put the default device in front. 
+ for (unsigned i = 1; i < n_candidate_devices; i++) { + if (&candidate_devices[i] == default_device) { + std::swap(candidate_devices[0], candidate_devices[i]); + default_device = &candidate_devices[0]; + break; + } } - dev_ctx->platform = default_device->platform->id; - dev_ctx->device = default_device->id; - backend_ctx->device = default_device->id; + GGML_LOG_INFO("ggml_opencl: selected platform: '%s'\n", default_device->platform->name); - if (strstr(default_device->name, "Adreno") || - strstr(default_device->name, "Qualcomm") || - strstr(default_device->version, "Adreno")) { + std::vector device_ids; + for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) { + device_ids.push_back(dev->id); + } + + cl_int err; + cl_context shared_context; + cl_context_properties properties[] = { (intptr_t) CL_CONTEXT_PLATFORM, (intptr_t) default_device->platform->id, 0 }; + + CL_CHECK( + (shared_context = clCreateContext(properties, device_ids.size(), device_ids.data(), NULL, NULL, &err), err)); + + for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) { + GGML_LOG_INFO("\nggml_opencl: device: '%s (%s)'\n", dev->name, dev->version); + + auto dev_ctx = std::unique_ptr(new ggml_backend_opencl_device_context{ + /*.platform =*/dev->platform->id, + /*.platform_nane =*/dev->platform->name, + /*.device =*/dev->id, + /*.device_name =*/dev->name, + /*.device_type =*/dev->type, + /*.device_version =*/dev->version, + /*.backend_ctx =*/nullptr, + /*.buffer_type =*/{}, + /*.context =*/shared_context, + }); + + found_devices.push_back(ggml_backend_device{ + /* .iface = */ ggml_backend_opencl_device_i, + /* .reg = */ reg, + /* .context = */ dev_ctx.get(), + }); + + if (!ggml_cl2_init(&found_devices.back())) { + found_devices.pop_back(); + GGML_LOG_INFO("ggml_opencl: drop unsupported device.\n"); + continue; + } + + dev_ctx.release(); + } + + if (found_devices.size()) { + auto * dev_ctx = static_cast(found_devices.front().context); + GGML_LOG_INFO("ggml_opencl: default device: '%s (%s)'\n", dev_ctx->device_name.c_str(), + dev_ctx->device_version.c_str()); + + if (dev_ctx->device_type != CL_DEVICE_TYPE_GPU) { + GGML_LOG_WARN("ggml_opencl: warning, the default device is not a GPU: '%s'.\n", + dev_ctx->device_name.c_str()); + } + } + + return found_devices; +} + +// Initialize device if it is supported (returns nullptr if it is not). 
+static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { + GGML_ASSERT(dev); + GGML_ASSERT(dev->context); + + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + GGML_ASSERT(dev_ctx->platform); + GGML_ASSERT(dev_ctx->device); + + if (dev_ctx->backend_ctx) { + return dev_ctx->backend_ctx; + } + + auto backend_ctx = std::make_unique(); + backend_ctx->device = dev_ctx->device; + backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + + if (strstr(dev_ctx->device_name.c_str(), "Adreno") || + strstr(dev_ctx->device_name.c_str(), "Qualcomm") || + strstr(dev_ctx->device_version.c_str(), "Adreno")) { backend_ctx->gpu_family = GPU_FAMILY::ADRENO; // Usually device version contains the detailed device name - backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->version); + backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str()); if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { - backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->name); + backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str()); } // Use wave size of 64 for all Adreno GPUs. backend_ctx->adreno_wave_size = 64; - } else if (strstr(default_device->name, "Intel")) { + } else if (strstr(dev_ctx->device_name.c_str(), "Intel")) { backend_ctx->gpu_family = GPU_FAMILY::INTEL; } else { - GGML_LOG_ERROR("Unsupported GPU: %s\n", default_device->name); + GGML_LOG_ERROR("Unsupported GPU: %s\n", dev_ctx->device_name.c_str()); backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; - return backend_ctx; + return nullptr; } #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); - return backend_ctx; + return nullptr; } #endif // Populate backend device name - dev_ctx->platform_name = default_device->platform->name; - dev_ctx->device_name = default_device->name; - backend_ctx->device_name = default_device->name; + backend_ctx->device_name = dev_ctx->device_name; // A local ref of cl_device_id for convenience cl_device_id device = backend_ctx->device; - ggml_cl_version platform_version = get_opencl_platform_version(default_device->platform->id); + ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); // Check device OpenCL version, OpenCL 2.0 or above is required ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, device); if (opencl_c_version.major < 2) { GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n"); - return backend_ctx; + return nullptr; } // Check driver version @@ -1364,7 +1563,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { // fp16 is required if (!backend_ctx->fp16_support) { GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n"); - return backend_ctx; + return nullptr; } // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes @@ -1373,7 +1572,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { strstr(ext_buffer, "cl_intel_subgroups") == NULL) { GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " "(note that subgroups is an optional feature in OpenCL 3.0)\n"); - return backend_ctx; + return nullptr; } cl_uint base_align_in_bits; @@ -1397,6 +1596,15 @@ static 
ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n", svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); + if (opencl_c_version.major >= 3) { + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool), + &backend_ctx->non_uniform_workgroups, 0)); + } else { + GGML_ASSERT(opencl_c_version.major == 2); + // Non-uniform workgroup sizes is mandatory feature in v2.x. + backend_ctx->non_uniform_workgroups = true; + } + // Print out configurations #ifdef GGML_OPENCL_SOA_Q GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); @@ -1406,14 +1614,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); #endif // GGML_OPENCL_USE_ADRENO_KERNELS - cl_context_properties properties[] = { - (intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)dev_ctx->platform, 0 - }; - - CL_CHECK((backend_ctx->context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err)); + cl_int err; // A local ref of cl_context for convenience - cl_context context = backend_ctx->context; + cl_context context = backend_ctx->context = dev_ctx->context; //CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err), // (err != CL_INVALID_QUEUE_PROPERTIES && err != CL_INVALID_VALUE ? err : @@ -1426,7 +1630,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); // Load kernels - load_cl_kernels(backend_ctx, opencl_c_version); + load_cl_kernels(backend_ctx.get(), opencl_c_version); #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Allocate intermediate buffers and images @@ -1456,10 +1660,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err)); #endif // GGML_OPENCL_USE_ADRENO_KERNELS - // For now we support a single devices - ggml_backend_opencl_n_devices = 1; - - return backend_ctx; + dev_ctx->backend_ctx = backend_ctx.release(); + return dev_ctx->backend_ctx; } static void ggml_cl2_free(void) { @@ -1664,10 +1866,46 @@ static void ggml_backend_opencl_synchronize(ggml_backend_t backend) { GGML_UNUSED(backend); } +// Syncronizes the 'backend_ctx's device with others so that commands +// enqueued to it won't start until commands in the other devices have +// completed. +static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) { + if (g_ggml_backend_opencl_devices.size() < 2) + return; // No other devices to synchronize with. 
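    // (editor's note) The block below implements a cross-queue ordering point:
    // a marker event is enqueued and flushed on every *other* registered
    // device's command queue, and a barrier waiting on all of those events is
    // then enqueued on this backend's queue. Work submitted to this queue
    // afterwards therefore cannot start until previously submitted work on the
    // other devices has completed.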
+ + std::vector events; + events.reserve(g_ggml_backend_opencl_devices.size()); + + for (ggml_backend_device & backend_dev : g_ggml_backend_opencl_devices) { + auto * other_backend_ctx = ggml_cl2_init(&backend_dev); + if (backend_ctx != other_backend_ctx) { + cl_event ev; + CL_CHECK(clEnqueueMarkerWithWaitList(other_backend_ctx->queue, 0, nullptr, &ev)); + CL_CHECK(clFlush(other_backend_ctx->queue)); + events.push_back(ev); + } + } + + CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, events.size(), events.data(), nullptr)); + for (auto ev : events) { + CL_CHECK(clReleaseEvent(ev)); + } +} + +static void sync_with_other_backends(ggml_backend_t backend) { + auto * backend_ctx = static_cast(backend->context); + sync_with_other_backends(backend_ctx); +} + static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + // NOTE: this may oversynchronize by synchronizing with + // backends/devices which don't compute 'cgraph's + // dependencies. + sync_with_other_backends(backend); + if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; } @@ -1729,6 +1967,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_ADD: case GGML_OP_SCALE: case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_SUB: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -1736,7 +1976,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_GELU_QUICK: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_UNARY_OP_SIGMOID: + return ggml_is_contiguous(op->src[0]); default: return false; } @@ -1746,11 +1988,13 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_NORM: case GGML_OP_RMS_NORM: return true; + case GGML_OP_GROUP_NORM: + return ggml_is_contiguous(op->src[0]); case GGML_OP_MUL_MAT: if (op->src[0]->type == GGML_TYPE_F16) { return true; } else if (op->src[0]->type == GGML_TYPE_F32) { - return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); + return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q6_K) { return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); @@ -1785,6 +2029,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } case GGML_OP_IM2COL: return true; + case GGML_OP_ARGSORT: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SUM_ROWS: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); default: return false; } @@ -2058,15 +2306,16 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, // The original tensor memory is divided into scales and quants, i.e., // we first store scales, then quants. // Create subbuffer for scales. 
- region.origin = extra_orig->offset + tensor->view_offs + offset; + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); region.size = size_d; extra->d = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); + auto previous_origin = region.origin; // Create subbuffer for quants. - region.origin = extra_orig->offset + tensor->view_offs + offset + size_d; + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); region.size = size_q; extra->q = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, @@ -2271,8 +2520,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, cl_context context = backend_ctx->context; cl_command_queue queue = backend_ctx->queue; - // Make sure all previously submitted commands are finished. - CL_CHECK(clFinish(queue)); + // Make sure all previously submitted commands in other devices are finished. + sync_with_other_backends(backend_ctx); #ifdef GGML_OPENCL_SOA_Q // In end-to-end runs, get_tensor is usually used to get back the logits, @@ -2376,13 +2625,8 @@ static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_b } static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) { - // FIXME: not thread safe, device may not be initialized yet - static cl_uint alignment = -1; - if (alignment == (cl_uint)-1) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); - alignment = backend_ctx->alignment; - } - return alignment; + ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); + return backend_ctx->alignment; } static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { @@ -2409,16 +2653,6 @@ static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = { /* .is_host = */ NULL, }; -ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type() { - static ggml_backend_buffer_type buffer_type = { - /* .iface = */ ggml_backend_opencl_buffer_type_interface, - /* .device = */ &g_ggml_backend_opencl_device, - /* .context = */ nullptr, - }; - - return &buffer_type; -} - // // backend device // @@ -2476,9 +2710,15 @@ static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, co } static ggml_backend_buffer_type_t ggml_backend_opencl_device_get_buffer_type(ggml_backend_dev_t dev) { - return ggml_backend_opencl_buffer_type(); + auto * dev_ctx = static_cast(dev->context); - GGML_UNUSED(dev); + dev_ctx->buffer_type = ggml_backend_buffer_type{ + /* .iface = */ ggml_backend_opencl_buffer_type_interface, + /* .device = */ dev, + /* .context = */ nullptr, + }; + + return &dev_ctx->buffer_type; } static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { @@ -2494,12 +2734,21 @@ static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const } static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == ggml_backend_opencl_buffer_type_get_name; + // Check 'dev' and 'buffer_type' are not objects belonging to this backend. + if (dev->iface.get_name != ggml_backend_opencl_device_get_name || + buft->iface.get_name != ggml_backend_opencl_buffer_type_get_name) { + return false; + } - GGML_UNUSED(dev); + // Check cl_context is the same. 
clEnqueue* commands may not use + // buffers from another cl_context. + ggml_backend_opencl_context * backend_ctx0 = ggml_cl2_init(dev); + ggml_backend_opencl_context * backend_ctx1 = ggml_cl2_init(buft->device); + return backend_ctx0->context == backend_ctx1->context; } -static struct ggml_backend_device_i ggml_backend_opencl_device_i = { +namespace /* anonymous */ { +struct ggml_backend_device_i ggml_backend_opencl_device_i = { /* .get_name = */ ggml_backend_opencl_device_get_name, /* .get_description = */ ggml_backend_opencl_device_get_description, /* .get_memory = */ ggml_backend_opencl_device_get_memory, @@ -2516,6 +2765,7 @@ static struct ggml_backend_device_i ggml_backend_opencl_device_i = { /* .event_free = */ NULL, /* .event_synchronize = */ NULL, }; +} // Backend registry @@ -2526,15 +2776,15 @@ static const char * ggml_backend_opencl_reg_get_name(ggml_backend_reg_t reg) { } static size_t ggml_backend_opencl_reg_device_count(ggml_backend_reg_t reg) { - return ggml_backend_opencl_n_devices; + return g_ggml_backend_opencl_devices.size(); GGML_UNUSED(reg); } static ggml_backend_dev_t ggml_backend_opencl_reg_device_get(ggml_backend_reg_t reg, size_t index) { - GGML_ASSERT(index == 0); + GGML_ASSERT(index < ggml_backend_opencl_reg_device_count(reg)); - return &g_ggml_backend_opencl_device; + return &g_ggml_backend_opencl_devices[index]; GGML_UNUSED(reg); GGML_UNUSED(index); @@ -2548,27 +2798,23 @@ static struct ggml_backend_reg_i ggml_backend_opencl_reg_i = { }; ggml_backend_reg_t ggml_backend_opencl_reg(void) { - // TODO: make this thread-safe somehow? + static std::mutex mutex; static ggml_backend_reg reg; static bool initialized = false; + std::lock_guard lock(mutex); - if (!initialized) { - reg = ggml_backend_reg { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_opencl_reg_i, - /* .context = */ NULL, - }; - - g_ggml_backend_opencl_device = ggml_backend_device { - /* .iface = */ ggml_backend_opencl_device_i, - /* .reg = */ ®, - /* .context = */ &g_ggml_ctx_dev_main, - }; - - ggml_cl2_init(&g_ggml_backend_opencl_device); - - initialized = true; + if (initialized) { + return ® } + initialized = true; + + g_ggml_backend_opencl_devices = ggml_opencl_probe_devices(®); + + reg = ggml_backend_reg{ + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_opencl_reg_i, + /* .context = */ NULL, + }; return ® } @@ -2942,14 +3188,19 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. 
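        // (editor's note) Passing a NULL local work size lets the OpenCL
        // runtime pick work-group sizes that evenly divide the global size.
        // This is required on devices without
        // CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT whenever n is not a
        // multiple of the preferred work-group size of 64.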
+ } + #ifdef GGML_OPENCL_PROFILING cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); g_profiling_info.emplace_back(); - populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); #else - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); #endif } else { unsigned int nth = MIN(64, ne0); @@ -3072,6 +3323,261 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); } + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const int ne0 = dst->ne[0]; + + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + 
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + bool bcast_row = false; + cl_kernel kernel; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + kernel = backend_ctx->kernel_div_row; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + kernel = backend_ctx->kernel_div; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + } + + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, 
&evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const int ne0 = dst->ne[0]; + + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + bool bcast_row = false; + cl_kernel kernel; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + kernel = backend_ctx->kernel_sub_row; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + kernel = backend_ctx->kernel_sub; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + 
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + } + if (bcast_row) { int n = ggml_nelements(dst)/4; size_t global_work_size[] = {(size_t)n, 1, 1}; @@ -3233,14 +3739,19 @@ static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + #ifdef GGML_OPENCL_PROFILING cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); g_profiling_info.emplace_back(); - populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); #else - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); #endif } @@ -3273,14 +3784,71 @@ static void ggml_cl_relu(ggml_backend_t backend, const ggml_tensor * src0, const size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. 
+ } + #ifdef GGML_OPENCL_PROFILING cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); g_profiling_info.emplace_back(); - populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); #else - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sigmoid_f32; + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + kernel = backend_ctx->kernel_sigmoid_f16; + } else { + GGML_ASSERT(false && "Unsupported data types for sigmoid (input and output must be both f32 or f16)"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + const int64_t n = ggml_nelements(dst); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); #endif } @@ -3320,14 +3888,19 @@ static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, cons size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. 
+ } + #ifdef GGML_OPENCL_PROFILING cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); g_profiling_info.emplace_back(); - populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); #else - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); #endif } @@ -3476,6 +4049,65 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c #endif } +static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + int32_t n_groups = ((const int32_t *) dst->op_params)[0]; + int32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + n_groups - 1) / n_groups); + float eps = ((const float *) dst->op_params)[1]; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne = ne00*ne01*ne02; + + cl_kernel kernel = backend_ctx->kernel_group_norm; + + size_t sgs = 64; + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; + } else { + GGML_ASSERT(false && "Unsupported GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &group_size)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); + + size_t global_work_size[] = {(size_t)n_groups*sgs, 1, 1}; + size_t local_work_size[] = {(size_t)sgs, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -4230,14 +4862,19 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && 
!backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + #ifdef GGML_OPENCL_PROFILING cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); g_profiling_info.emplace_back(); - populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); #else - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); #endif } @@ -4418,14 +5055,19 @@ static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * sr size_t global_work_size[] = {(size_t)ne00, (size_t)ne01, (size_t)ne02}; size_t local_work_size[] = {64, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (ne00 % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + #ifdef GGML_OPENCL_PROFILING cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); g_profiling_info.emplace_back(); - populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); #else - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); #endif } } @@ -4815,6 +5457,124 @@ static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, con #endif } +static void ggml_cl_argsort(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int nrows = ggml_nrows(src0); + + int ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } + + int order = (enum ggml_sort_order) dst->op_params[0]; + + cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + 
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00_padded)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &order)); + CL_CHECK(clSetKernelArg(kernel, 7, ne00_padded*sizeof(int), NULL)); + + size_t global_work_size[] = {(size_t)ne00_padded, (size_t)nrows, (size_t)1}; + size_t local_work_size[] = {(size_t)ne00_padded, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + cl_kernel kernel = backend_ctx->kernel_sum_rows_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3)); + + size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + //------------------------------------------------------------------------------ // Op offloading //------------------------------------------------------------------------------ 
@@ -4863,6 +5623,18 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_mul; break; + case GGML_OP_DIV: + if (!any_on_device) { + return false; + } + func = ggml_cl_div; + break; + case GGML_OP_SUB: + if (!any_on_device) { + return false; + } + func = ggml_cl_sub; + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_GELU: @@ -4889,6 +5661,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_relu; break; + case GGML_UNARY_OP_SIGMOID: + if (!any_on_device) { + return false; + } + func = ggml_cl_sigmoid; + break; default: return false; } break; @@ -4910,6 +5688,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_rms_norm; break; + case GGML_OP_GROUP_NORM: + if (!any_on_device) { + return false; + } + func = ggml_cl_group_norm; + break; case GGML_OP_MUL_MAT: if (!any_on_device && !ggml_cl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) { return false; @@ -4955,6 +5739,18 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_im2col; break; + case GGML_OP_ARGSORT: + if (!any_on_device) { + return false; + } + func = ggml_cl_argsort; + break; + case GGML_OP_SUM_ROWS: + if (!any_on_device) { + return false; + } + func = ggml_cl_sum_rows; + break; default: return false; } diff --git a/ggml/src/ggml-opencl/kernels/argsort.cl b/ggml/src/ggml-opencl/kernels/argsort.cl new file mode 100644 index 000000000..af4adc7b8 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/argsort.cl @@ -0,0 +1,86 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define SWAP(x, y, T) { T tmp = (x); (x) = (y); (y) = tmp; } + +enum ggml_sort_order { + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, +}; + +kernel void kernel_argsort_f32_i32( + global float * src0, + ulong offset0, + global int * dst, + ulong offsetd, + const int ne00, + const int ne00_pad, + const int order, + local int * dst_row +) { + // bitonic sort + int col = get_local_id(0); + int row = get_group_id(1); + + if (col >= ne00_pad) { + return; + } + + src0 = (global char *)((global char *)src0 + offset0); + dst = (global float *)((global char *)dst + offsetd); + + global float * x_row = src0 + row * ne00; + + // initialize indices + dst_row[col] = col; + + barrier(CLK_LOCAL_MEM_FENCE); + + for (int k = 2; k <= ne00_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= ne00 || + (dst_row[ixj] < ne00 && (order == GGML_SORT_ORDER_ASC ? 
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj], int); + } + } else { + if (dst_row[ixj] >= ne00 || + (dst_row[col] < ne00 && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj], int); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + } + + // copy the result to dst without the padding + if (col < ne00) { + dst[row * ne00 + col] = dst_row[col]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/div.cl b/ggml/src/ggml-opencl/kernels/div.cl new file mode 100644 index 000000000..d453ad99b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/div.cl @@ -0,0 +1,72 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// div +//------------------------------------------------------------------------------ +kernel void kernel_div( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) / *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_div_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. 
+ uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] / src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/group_norm.cl b/ggml/src/ggml-opencl/kernels/group_norm.cl new file mode 100644 index 000000000..57c9df4d3 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/group_norm.cl @@ -0,0 +1,72 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +// Workgroup must be a subgroup +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_group_norm( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne, + int group_size, + float eps +) { + src0 = (global float *)((global char *)src0 + offset0); + dst = (global float *)((global char *)dst + offsetd); + + int start = get_group_id(0) * group_size; + int end = start + group_size; + + start += get_local_id(0); + + if (end >= ne) { + end = ne; + } + + float tmp = 0.0f; + + for (int j = start; j < end; j += get_local_size(0)) { + tmp += src0[j]; + } + + tmp = sub_group_reduce_add(tmp); + + const float mean = tmp / group_size; + tmp = 0.0f; + + for (int j = start; j < end; j += get_local_size(0)) { + float xi = src0[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = sub_group_reduce_add(tmp); + + const float variance = tmp / group_size; + const float scale = 1.0f/sqrt(variance + eps); + for (int j = start; j < end; j += get_local_size(0)) { + dst[j] *= scale; + } +} diff --git a/ggml/src/ggml-opencl/kernels/sigmoid.cl b/ggml/src/ggml-opencl/kernels/sigmoid.cl new file mode 100644 index 000000000..e3f669dde --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sigmoid.cl @@ -0,0 +1,29 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// sigmoid +//------------------------------------------------------------------------------ + +kernel void kernel_sigmoid_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = 1.0f / (1.0f + exp(-src0[get_global_id(0)])); +} + +kernel void kernel_sigmoid_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = 1.0f / (1.0f + exp(-src0[get_global_id(0)])); +} diff --git a/ggml/src/ggml-opencl/kernels/sub.cl b/ggml/src/ggml-opencl/kernels/sub.cl new file mode 100644 index 000000000..041e88ad3 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sub.cl @@ -0,0 +1,72 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + 
+//------------------------------------------------------------------------------ +// div +//------------------------------------------------------------------------------ +kernel void kernel_sub( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) - *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_sub_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] - src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/sum_rows.cl b/ggml/src/ggml-opencl/kernels/sum_rows.cl new file mode 100644 index 000000000..c5f7c570f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sum_rows.cl @@ -0,0 +1,39 @@ + +kernel void kernel_sum_rows_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float *)((global char *)src0 + offset0); + dst = (global float *)((global char *)dst + offsetd); + + int i3 = get_global_id(2); + int i2 = get_global_id(1); + int i1 = get_global_id(0); + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + float row_sum = 0; + + for (int i0 = 0; i0 < ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum; +} diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index a2e261248..2a0045bcc 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -13,7 +13,7 @@ elseif(SUPPORTS_SYCL) If you expected the oneAPI Release compiler, please install oneAPI & source it, like: source /opt/intel/oneapi/setvars.sh") else() - message(FATAL_ERROR, "C++ compiler lacks SYCL support.") + message(FATAL_ERROR "C++ compiler lacks SYCL support.") endif() message(STATUS "SYCL found") #todo: AOT @@ -170,7 +170,7 @@ else() target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA) elseif (GGML_SYCL_TARGET STREQUAL "AMD") if (NOT GGML_SYCL_DEVICE_ARCH) - message(ERROR "Can't enable SYCL 
hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") + message(FATAL_ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") endif() target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas) target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa") diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index aaa94176f..0a3883ae1 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -1,74 +1,93 @@ #include "binbcast.hpp" -#include #include #include #include -#include "dpct/helper.hpp" #include "ggml.h" -template -static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, - dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) { - auto element_id = it.get_global_id(0); - auto global_range = it.get_global_range(0); - for (; element_id < num_elements; element_id += global_range) { - auto src0_float_val = sycl::vec(src0[element_id]).template convert(); - auto src1_float_val = sycl::vec(src1[element_id]).template convert(); - float dst_val = bin_op(src0_float_val[0], src1_float_val[0]); - auto val_to_store = sycl::vec(dst_val).template convert(); - dst[element_id] = val_to_store; +template +static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { + const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) / + ne3; + const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) % + ne3; + + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + for (int i0 = i0s; i0 < ne0; + i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? 
(float)src0_row[i0] : 0.0f, (float)src1_row[i10]); } } -template -static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, - int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10, - int s11, int s12, int s13, std::size_t num_dst_elements, - const sycl::nd_item<1> & item_ct1) { - auto calculate_logical_index = - [](const std::array & dims, std::size_t element_id) __attribute__((always_inline))->std::array { - std::array logical_index; -#pragma unroll(4) - for (int i = 3; i >= 0; i--) { - logical_index[i] = element_id % dims[i]; - element_id /= dims[i]; - } - return logical_index; - }; +template +static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { - auto calculate_index = [](const std::array & dims, const std::array & strides, - const std::array & indices) __attribute__((always_inline)) - ->std::size_t { - std::size_t index = 0; -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - auto index_i = indices[i]; - if (indices[i] >= dims[i]) { - index_i = indices[i] % dims[i]; - } - index += strides[i] * index_i; - } - return index; - }; + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); - auto element_id = item_ct1.get_global_id(0); - for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) { - auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id); - auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index); - auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index); - auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index); - auto src0_float_val = sycl::vec(src0[src_0_index]).template convert(); - auto src1_float_val = sycl::vec(src1[src_1_index]).template convert(); - float dst_val = bin_op(src0_float_val[0], src1_float_val[0]); - auto val_to_store = sycl::vec(dst_val).template convert(); - dst[dst_index] = val_to_store; + const int i3 = i/(ne2*ne1*ne0); + const int i2 = (i/(ne1*ne0)) % ne2; + const int i1 = (i/ne0) % ne1; + const int i0 = i % ne0; + + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? 
(float)src0_row[i0] : 0.0f, (float)src1_row[i10]); } -template struct bin_bcast_sycl { + +template +struct bin_bcast_sycl { template void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, @@ -77,73 +96,165 @@ template struct bin_bcast_sycl { const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0, const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous, const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) { - auto check_bcast_required = [](const std::array & src_dims, - const std::array & dst_dims) -> bool { - for (int i = 0; i < 4; i++) { - if (dst_dims[i] > src_dims[i]) { - return true; - } - } - return false; + int nr0 = ne10 / ne0; + int nr1 = ne11/ne1; + int nr2 = ne12/ne2; + int nr3 = ne13/ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + // collapse dimensions until first broadcast dimension + int64_t cne[] = {ne0, ne1, ne2, ne3}; + int64_t cne0[] = {ne00, ne01, ne02, ne03}; + int64_t cne1[] = {ne10, ne11, ne12, ne13}; + size_t cnb[] = {nb0, nb1, nb2, nb3}; + size_t cnb0[] = {nb00, nb01, nb02, nb03}; + size_t cnb1[] = {nb10, nb11, nb12, nb13}; + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; }; - dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + auto collapse_nb = [](size_t cnb[], int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; - GGML_ASSERT(nb0 % sizeof(dst_t) == 0); - GGML_ASSERT(nb1 % sizeof(dst_t) == 0); - GGML_ASSERT(nb2 % sizeof(dst_t) == 0); - GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) { + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb, cne); + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne); + collapse(cne0); + collapse(cne1); + } + } + } + { + int64_t ne0 = cne[0]; + int64_t ne1 = cne[1]; + int64_t ne2 = cne[2]; + int64_t ne3 = cne[3]; - GGML_ASSERT(nb00 % sizeof(src0_t) == 0); - GGML_ASSERT(nb01 % sizeof(src0_t) == 0); - GGML_ASSERT(nb02 % sizeof(src0_t) == 0); - GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + int64_t ne10 = cne1[0]; + int64_t ne11 = cne1[1]; + int64_t ne12 = cne1[2]; + int64_t ne13 = cne1[3]; - GGML_ASSERT(nb10 % sizeof(src1_t) == 0); - GGML_ASSERT(nb11 % sizeof(src1_t) == 0); - GGML_ASSERT(nb12 % sizeof(src1_t) == 0); - GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + size_t nb0 = cnb[0]; + size_t nb1 = cnb[1]; + size_t nb2 = cnb[2]; + size_t nb3 = cnb[3]; - // dst strides in number of elements - size_t s0 = nb0 / sizeof(dst_t); - size_t s1 = nb1 / sizeof(dst_t); - size_t s2 = nb2 / sizeof(dst_t); - size_t s3 = nb3 / sizeof(dst_t); + size_t nb00 = cnb0[0]; + size_t nb01 = cnb0[1]; + size_t nb02 = cnb0[2]; + size_t nb03 = cnb0[3]; - // src1 strides in number of elements - size_t s10 = nb10 / sizeof(src0_t); - size_t s11 = nb11 / sizeof(src1_t); - size_t s12 = nb12 / sizeof(src1_t); - size_t s13 = nb13 / sizeof(src1_t); + size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; - // src0 strides in number of elements - size_t s00 = nb00 / sizeof(src0_t); - size_t s01 = nb01 / sizeof(src0_t); - size_t s02 = nb02 / sizeof(src0_t); - size_t s03 = nb03 / sizeof(src0_t); + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t 
s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); - std::size_t num_dst_elements = static_cast(ne0) * static_cast(ne1) * - static_cast(ne2) * static_cast(ne3); - std::size_t local_range = 256; - std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range; + size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); - bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) || - check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 }); - bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous; + size_t s00 = nb00 / sizeof(src0_t); + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); - if (! needs_broadcasting && all_contiguous) { - stream->submit([&](sycl::handler & cgh) { - cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) { - k_bin_bcast_contiguous(src0_dd, src1_dd, dst_dd, num_dst_elements, it); - }); - }); - } else { - stream->submit([&](sycl::handler & cgh) { - cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) { - k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1, - s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it); - }); - }); + GGML_UNUSED(s00); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s10 == 1); + + const int block_size = 128; + + int64_t hne0 = std::max(ne0/2LL, 1LL); + + sycl::range<3> block_dims(1, 1, 1); + block_dims[2] = std::min(hne0, block_size); + block_dims[1] = std::min( + ne1, block_size / (unsigned int)block_dims[2]); + block_dims[0] = std::min( + std::min( + ne2 * ne3, block_size / (unsigned int)block_dims[2] / + (unsigned int)block_dims[1]), + 64U); + + sycl::range<3> block_nums( + (ne2 * ne3 + block_dims[0] - 1) / block_dims[0], + (ne1 + block_dims[1] - 1) / block_dims[1], + (hne0 + block_dims[2] - 1) / block_dims[2]); + + if (block_nums[0] > 65535) { + // this is the maximum number of blocks in z direction, fallback to 1D grid kernel + int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * + sycl::range<3>(1, 1, block_size), + sycl::range<3>(1, 1, block_size)), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast_unravel( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, + s03, s11, s12, s13, item_ct1); + }); + } + } else { + /* + DPCT1049:16: The work-group size passed to the SYCL kernel may + exceed the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if + needed. 
+ */ + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, + ne2, ne3, ne10, ne11, ne12, ne13, + s1, s2, s3, s01, s02, s03, s11, s12, s13, + item_ct1); + }); + } } } }; @@ -208,32 +319,27 @@ inline void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *ds void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_add(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_sub(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_mul(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_div(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_repeat(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 60909dde7..15ee9dc69 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -13,8 +13,10 @@ #ifndef GGML_SYCL_COMMON_HPP #define GGML_SYCL_COMMON_HPP +#include #include #include +#include #include "dpct/helper.hpp" #include "ggml-sycl.h" @@ -44,11 +46,20 @@ extern int g_ggml_sycl_debug; extern int g_ggml_sycl_disable_optimize; extern int g_ggml_sycl_prioritize_dmmv; -#define GGML_SYCL_DEBUG(...) \ - do { \ - if (g_ggml_sycl_debug) \ - fprintf(stderr, __VA_ARGS__); \ - } while (0) +#if defined(__clang__) && __has_builtin(__builtin_expect) +// Hint the optimizer to pipeline the more likely following instruction in branches +# define LIKELY(expr) __builtin_expect(expr, true) +# define UNLIKELY(expr) __builtin_expect(expr, false) +#else +# define LIKELY(expr) (expr) +# define UNLIKELY(expr) (expr) +#endif + +#define GGML_SYCL_DEBUG(...) 
\ + do { \ + if (UNLIKELY(g_ggml_sycl_debug)) \ + fprintf(stderr, __VA_ARGS__); \ + } while (0) #define CHECK_TRY_ERROR(expr) \ [&]() { \ @@ -471,6 +482,19 @@ static __dpct_inline__ float warp_reduce_max(float x, return x; } +/* Helper for Computing the linear offset of a ggml_tensor given +per-dimension sizes, strides, and indices */ +template +__dpct_inline__ size_t calculate_offset(const std::array & strides, const std::array & indices) { + size_t offset = 0; +#pragma unroll + for (int i = 0; i < N; i++) { + auto index_i = indices[i]; + offset += strides[i] * index_i; + } + return offset; +} + // Helper for vec loading aligned data template inline sycl::vec vec_aligned_load(const Tp* aligned_ptr) { @@ -490,4 +514,76 @@ constexpr size_t ceil_div(const size_t m, const size_t n) { } bool gpu_has_xmx(sycl::device &dev); + +template void debug_print_array(const std::string & prefix, const T array[N]) { + if (LIKELY(!g_ggml_sycl_debug)) { + return; + } + std::stringstream ss; + ss << prefix << "=["; + for (std::size_t i = 0; i < N - 1; ++i) { + ss << array[i] << ", "; + } + if constexpr (N > 0) { + ss << array[N - 1]; + } + ss << "]"; + GGML_SYCL_DEBUG("%s", ss.str().c_str()); +} + +inline void debug_print_tensor(const std::string & prefix, const ggml_tensor * tensor, + const std::string & suffix = "") { + if (LIKELY(!g_ggml_sycl_debug)) { + return; + } + GGML_SYCL_DEBUG("%s=", prefix.c_str()); + if (tensor) { + GGML_SYCL_DEBUG("'%s':type=%s", tensor->name, ggml_type_name(tensor->type)); + debug_print_array(";ne", tensor->ne); + debug_print_array(";nb", tensor->nb); + if (!ggml_is_contiguous(tensor)) { + GGML_SYCL_DEBUG(";strided"); + } + if (ggml_is_permuted(tensor)) { + GGML_SYCL_DEBUG(";permuted"); + } + } else { + GGML_SYCL_DEBUG("nullptr"); + } + GGML_SYCL_DEBUG("%s", suffix.c_str()); +} + +// Use scope_op_debug_print to log operations coming from running a model +struct scope_op_debug_print { + // Use string_views to avoid the cost of creating a string and concatenating them + // string_views must be alive for as long as the object is alive + // scope_op_debug_print are used with string literals in practice which are stored in constant space so always accessible + scope_op_debug_print(const std::string_view & func, const std::string_view & func_suffix, const ggml_tensor * dst, + std::size_t num_src, const std::string_view & suffix = "") : + func(func), + func_suffix(func_suffix) { + if (LIKELY(!g_ggml_sycl_debug)) { + return; + } + GGML_SYCL_DEBUG("[SYCL][OP] call %s%s:", func.data(), func_suffix.data()); + debug_print_tensor(" dst", dst); + if (dst) { + for (std::size_t i = 0; i < num_src; ++i) { + debug_print_tensor("\tsrc" + std::to_string(i), dst->src[i]); + } + } + GGML_SYCL_DEBUG("%s\n", suffix.data()); + } + + scope_op_debug_print(const std::string_view & func, const ggml_tensor * dst, std::size_t num_src, + const std::string_view & suffix = "") : + scope_op_debug_print(func, "", dst, num_src, suffix) {} + + ~scope_op_debug_print() { GGML_SYCL_DEBUG("[SYCL][OP] call %s%s done\n", func.data(), func_suffix.data()); } + + private: + std::string_view func; + std::string_view func_suffix; +}; + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/concat.cpp b/ggml/src/ggml-sycl/concat.cpp index d41cfd3a6..7aa91c861 100644 --- a/ggml/src/ggml-sycl/concat.cpp +++ b/ggml/src/ggml-sycl/concat.cpp @@ -159,39 +159,37 @@ static void concat_f32_sycl_non_cont( } void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - const ggml_tensor *src0 = dst->src[0]; - 
const ggml_tensor *src1 = dst->src[1]; - queue_ptr stream = ctx.stream(); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + queue_ptr stream = ctx.stream(); - const int32_t dim = ((int32_t *)dst->op_params)[0]; + const int32_t dim = ((int32_t *) dst->op_params)[0]; - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - const float *src0_d = (const float *)src0->data; - const float *src1_d = (const float *)src1->data; + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; - float *dst_d = (float *)dst->data; + float * dst_d = (float *) dst->data; - if (dim != 3) { - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_sycl( - src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), - dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1], - src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); - } + if (dim != 3) { + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_sycl(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), + dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0], + dst->ne[1], dst->ne[2], dim, stream); + } + } else { + const size_t size0 = ggml_nbytes(src0); + const size_t size1 = ggml_nbytes(src1); + + SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait())); + SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait())); + } } else { - const size_t size0 = ggml_nbytes(src0); - const size_t size1 = ggml_nbytes(src1); - - SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait())); - SYCL_CHECK(CHECK_TRY_ERROR( - stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait())); + concat_f32_sycl_non_cont(stream, (const char *) src0->data, (const char *) src1->data, (char *) dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], + src0->nb[2], src0->nb[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], + dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim); } - } else - concat_f32_sycl_non_cont( - stream, (const char *)src0->data, (const char *)src1->data, - (char *)dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], - src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], - dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim); } diff --git a/ggml/src/ggml-sycl/conv.cpp b/ggml/src/ggml-sycl/conv.cpp index ddba601e1..475bd34a2 100644 --- a/ggml/src/ggml-sycl/conv.cpp +++ b/ggml/src/ggml-sycl/conv.cpp @@ -72,6 +72,7 @@ static void conv_transpose_1d_f32_f32_sycl( } void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); const ggml_tensor *src0 = dst->src[0]; const ggml_tensor *src1 = dst->src[1]; const float * src0_d = (const float *)src0->data; diff --git a/ggml/src/ggml-sycl/cpy.cpp b/ggml/src/ggml-sycl/cpy.cpp index 5a2314589..44487c256 100644 --- a/ggml/src/ggml-sycl/cpy.cpp +++ b/ggml/src/ggml-sycl/cpy.cpp @@ -616,6 +616,9 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co } void 
ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try {
+    // Unlike other operators ggml_sycl_cpy takes 2 distinct tensors instead of a dst ggml_tensor and rely on its src field
+    scope_op_debug_print scope_dbg_print(__func__, src1, /*num_src=*/0,
+                                         std::string(" src0 type=") + ggml_type_name(src0->type));
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -629,8 +632,6 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
-    GGML_SYCL_DEBUG("[SYCL] %s: Tensor supplied: %s to %s\n", __func__, ggml_type_name(src0->type),
-                    ggml_type_name(src1->type));
     if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
@@ -694,8 +695,6 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
 }
 
 void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    // TODO: why do we pass dst as src1 here?
-    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_cpy(ctx, dst->src[0], dst);
-    GGML_SYCL_DEBUG("[SYCL] call %s done\n", __func__);
 }
diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp
index b58150c68..4f2760110 100644
--- a/ggml/src/ggml-sycl/dmmv.cpp
+++ b/ggml/src/ggml-sycl/dmmv.cpp
@@ -1092,6 +1092,8 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
         src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
 
     if (src1_convert_f16) {
+        scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2,
+                                             " : converting src1 to fp16");
         src1_dfloat = src1_dfloat_a.alloc(ne00);
         const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
         GGML_ASSERT(to_fp16_sycl != nullptr);
diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp
index becaac404..5b7c4f0b4 100644
--- a/ggml/src/ggml-sycl/element_wise.cpp
+++ b/ggml/src/ggml-sycl/element_wise.cpp
@@ -84,6 +84,15 @@ static void gelu_quick(const T *x, T *dst, int k,
     dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
 }
 
+template<typename T>
+static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
+    for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+        auto x_i = x[i];
+        dst[i] = static_cast<T>(0.5f) * x_i * (static_cast<T>(1.0f) + sycl::erf(x_i * SQRT_2_INV));
+    }
+}
+
 template<typename T>
 static void tanh(const T *x, T *dst, int k,
                  const sycl::nd_item<3> &item_ct1) {
@@ -400,6 +409,20 @@ static void gelu_quick_sycl(const T *x, T *dst, const int k,
         });
 }
 
+
+template<typename T>
+static void gelu_erf_sycl(const T *x, T *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            gelu_erf(x, dst, k, item_ct1);
+        });
+}
+
 template<typename T>
 static void tanh_sycl(const T *x, T *dst, const int k,
                       queue_ptr stream) {
@@ -816,6 +839,38 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
     }
 }
 
+inline void
ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data<sycl::half>(dst); + gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data<float>(dst); + gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + + inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { #if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); @@ -1391,146 +1446,126 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_sqrt(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_sin(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_cos(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_acc(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_gelu(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_silu(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_gelu_quick(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + 
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_gelu_erf(ctx, dst); } void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_tanh(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_relu(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_sigmoid(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_hardsigmoid(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_hardswish(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } - void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_exp(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_log(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_neg(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_step(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_leaky_relu(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_sqr(ctx, dst); - GGML_SYCL_DEBUG("call %s 
done\n", __func__); } void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_upscale(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_pad(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_clamp(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_sgn(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_abs(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_elu(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index f4199d69d..bd40113f0 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -38,6 +38,8 @@ void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp index 64665be46..4a7712781 100644 --- a/ggml/src/ggml-sycl/getrows.cpp +++ b/ggml/src/ggml-sycl/getrows.cpp @@ -257,8 +257,7 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens GGML_UNUSED(ctx); } -void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ -308,4 +307,3 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { GGML_ABORT("fatal error"); } } - diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index d05919781..bcd2ea536 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -346,6 +346,8 @@ static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) { static enum ggml_status 
ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor *tensor) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + debug_print_tensor(": tensor=", tensor, "\n"); ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context; if (tensor->view_src != NULL) { @@ -381,7 +383,9 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data, size_t offset, size_t size) try { - + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + debug_print_tensor(": tensor=", tensor); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; ggml_sycl_set_device(ctx->device); auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); @@ -407,7 +411,9 @@ static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor *tensor, void *data, size_t offset, size_t size) try { - + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + debug_print_tensor(": tensor=", tensor); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; ggml_sycl_set_device(ctx->device); @@ -435,7 +441,12 @@ static bool ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor *src, ggml_tensor *dst) try { - if (ggml_backend_buffer_is_sycl(src->buffer)) { + bool is_cpy_supported = ggml_backend_buffer_is_sycl(src->buffer); + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + debug_print_tensor(": dst=", dst); + debug_print_tensor(" src=", src); + GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported); + if (is_cpy_supported) { ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context; ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context; @@ -492,7 +503,8 @@ ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer, static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) try { - ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; + GGML_SYCL_DEBUG("[SYCL] call %s: size=%zu\n", __func__, buffer->size); + ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context; ggml_sycl_set_device(ctx->device); queue_ptr stream = ctx->stream; @@ -511,7 +523,9 @@ catch (sycl::exception const &exc) { static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { - GGML_SYCL_DEBUG(" [SYCL] call %s\n", __func__); + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + debug_print_tensor(": tensor=", tensor); + GGML_SYCL_DEBUG(" size=%zu offset=%zu value=%u\n", size, offset, value); ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context; SYCL_CHECK(ggml_sycl_set_device(ctx->device)); auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); @@ -789,6 +803,8 @@ static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buff static enum ggml_status ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor *tensor) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + debug_print_tensor(": tensor=", tensor, "\n"); GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported 
ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; @@ -873,6 +889,9 @@ static void ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data, size_t offset, size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + debug_print_tensor(": tensor=", tensor); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); // split tensors must always be set in their entirety at once GGML_ASSERT(offset == 0); GGML_ASSERT(size == ggml_nbytes(tensor)); @@ -926,6 +945,9 @@ static void ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor *tensor, void *data, size_t offset, size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + debug_print_tensor(": tensor=", tensor); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); // split tensors must always be set in their entirety at once GGML_ASSERT(offset == 0); GGML_ASSERT(size == ggml_nbytes(tensor)); @@ -2015,12 +2037,12 @@ inline void ggml_sycl_op_mul_mat_sycl( #else bool use_fp16 = false; #endif - if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && - use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && - dst->op_params[0] == GGML_PREC_DEFAULT) { - // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n"); + if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) && + row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { ggml_sycl_pool_alloc src0_as_f16(ctx.pool()); if (src0->type != GGML_TYPE_F16) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, + " : converting src0 to fp16"); const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst); GGML_ASSERT(to_fp16_sycl != nullptr); size_t ne = row_diff*ne00; @@ -2033,6 +2055,8 @@ inline void ggml_sycl_op_mul_mat_sycl( ggml_sycl_pool_alloc src1_as_f16(ctx.pool()); if (src1->type != GGML_TYPE_F16) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, + " : converting src1 to fp16"); const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); GGML_ASSERT(to_fp16_sycl != nullptr); size_t ne = src1_ncols*ne10; @@ -2049,6 +2073,8 @@ inline void ggml_sycl_op_mul_mat_sycl( DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt(), src0_ptr, DnnlGemmWrapper::to_dt(), dst_f16.get(), DnnlGemmWrapper::to_dt(), stream); + scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2, + " : converting dst to fp32"); const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); } @@ -2064,21 +2090,25 @@ inline void ggml_sycl_op_mul_mat_sycl( src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, dst_f16.get(), dpct::library_data_t::real_half, ldc, dpct::library_data_t::real_half))); + scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2, + " : converting dst to fp32"); const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); } - } - else { - // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n"); + } else { ggml_sycl_pool_alloc src0_ddq_as_f32(ctx.pool()); ggml_sycl_pool_alloc src1_ddq_as_f32(ctx.pool()); if (src0->type != GGML_TYPE_F32) { + scope_op_debug_print 
scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2, + " : converting src0 to fp32"); const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst); GGML_ASSERT(to_fp32_sycl != nullptr); src0_ddq_as_f32.alloc(row_diff*ne00); to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream); } if (src1->type != GGML_TYPE_F32) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2, + " : converting src1 to fp32"); const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst); GGML_ASSERT(to_fp32_sycl != nullptr); src1_ddq_as_f32.alloc(src1_ncols*ne10); @@ -2114,8 +2144,7 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); dpct::queue_ptr main_stream = ctx.stream(); @@ -2167,8 +2196,7 @@ inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream); } -inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); dpct::queue_ptr main_stream = ctx.stream(); @@ -2199,8 +2227,7 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream); } -inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_I32); @@ -2215,8 +2242,7 @@ inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor *ds argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); } -inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx,ggml_tensor *dst) { - +inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); dpct::queue_ptr main_stream = ctx.stream(); @@ -2233,8 +2259,7 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx,ggml_tens diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); } -inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); dpct::queue_ptr main_stream = ctx.stream(); @@ -2421,6 +2446,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs); if (src1_on_device && src1_is_contiguous) { + scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, + /*num_src=*/2, " : converting src1 to Q8_1"); quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream); /* DPCT1010:90: SYCL uses exceptions to report errors and does not @@ -2525,6 +2552,8 @@ static void 
ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten } if (convert_src1_to_q8_1 && !src1_is_contiguous) { + scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, + /*num_src=*/2, " : converting src1 to Q8_1"); quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); /* DPCT1010:92: SYCL uses exceptions to report errors and does @@ -2619,33 +2648,28 @@ catch (sycl::exception const &exc) { static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_get_rows(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_norm(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_rms_norm(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_l2_norm(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_group_norm(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, @@ -2773,6 +2797,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons // convert src1 to fp16 if (src1->type != GGML_TYPE_F16) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2, + " : converting src1 to fp16"); const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type); GGML_ASSERT(to_fp16_nc_sycl != nullptr); const int64_t ne_src1 = ggml_nelements(src1); @@ -3076,6 +3102,7 @@ static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * } static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer); int64_t min_compute_capability = INT_MAX; @@ -3153,7 +3180,6 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor constexpr bool convert_src1_to_q8_1 = false; ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1); } - GGML_SYCL_DEBUG("call %s done\n", __func__); } @@ -3224,6 +3250,7 @@ __dpct_inline__ static void k_copy_dst_from_contiguous( static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, ggml_tensor *dst) try { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); const ggml_tensor *src0 = dst->src[0]; const ggml_tensor *src1 = dst->src[1]; GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers"); @@ -3392,37 
+3419,45 @@ catch (sycl::exception const &exc) { } static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_scale(ctx, dst); } static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_diag_mask_inf(ctx, dst); } static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_pool2d(ctx, dst); } static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_im2col(ctx, dst); } static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); GGML_ASSERT(ggml_is_contiguous(dst->src[0])); ggml_sycl_op_sum(ctx, dst); } static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); GGML_ASSERT(ggml_is_contiguous(dst->src[0])); ggml_sycl_op_sum_rows(ctx, dst); } static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); GGML_ASSERT(ggml_is_contiguous(dst->src[0])); ggml_sycl_op_argsort(ctx, dst); } static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); GGML_ASSERT(ggml_is_contiguous(dst->src[0])); ggml_sycl_op_argmax(ctx, dst); } @@ -3508,6 +3543,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_UNARY_OP_GELU_QUICK: ggml_sycl_gelu_quick(ctx, dst); break; + case GGML_UNARY_OP_GELU_ERF: + ggml_sycl_gelu_erf(ctx, dst); + break; case GGML_UNARY_OP_TANH: ggml_sycl_tanh(ctx, dst); break; @@ -3716,6 +3754,9 @@ static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend, ggml_tensor *tensor, const void *data, size_t offset, size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + debug_print_tensor(": tensor=", tensor); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; @@ -3734,13 +3775,16 @@ static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend, const ggml_tensor *tensor, void *data, size_t offset, size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + debug_print_tensor(": tensor=", tensor); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; ggml_backend_buffer_t buf = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type"); const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy( - data, (const char *)tensor->data + offset, size).wait())); + data, (const char *)tensor->data + offset, size))); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -3752,7 +3796,13 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor *src, ggml_tensor *dst) try { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; - if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) { + bool is_cpy_supported = dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && + ggml_backend_buffer_is_sycl(src->buffer); + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + debug_print_tensor(": dst=", dst); + debug_print_tensor(" src=", src); + GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported); + if (is_cpy_supported) { /* DPCT1009:215: SYCL uses exceptions to report errors and does not use the error codes. The original code was commented out and a warning string @@ -3760,7 +3810,7 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend, */ const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy( - dst->data, src->data, ggml_nbytes(dst)).wait())); + dst->data, src->data, ggml_nbytes(dst)))); return true; } @@ -3773,6 +3823,7 @@ catch (sycl::exception const &exc) { } static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try { + GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait())); @@ -3809,11 +3860,43 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc } } +#ifdef GGML_SYCL_GRAPH +static bool check_graph_compatibility(ggml_cgraph * cgraph) { + if (ggml_sycl_info().device_count > 1) { + // A sycl_ex::command_graph object can only be created for a single device + GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__); + return false; + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + const ggml_op node_op = cgraph->nodes[i]->op; + switch (node_op) { + default: + break; + case GGML_OP_CONCAT: + // ggml_sycl_op_concat() does a blocking host wait after memcpy operations, + // but wait() can't be called on the events returned by a queue recording + // to a graph. + [[fallthrough]]; + case GGML_OP_MUL_MAT_ID: + // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after + // submitting a memcpy operation, but wait() can't be called on a queue that + // is recording to a graph. 
+ GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__, + ggml_op_name(node_op)); + return false; + } + } + return true; +} +#endif + static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { auto * sycl_ctx = static_cast(backend->context); #ifdef GGML_SYCL_GRAPH - if (!g_ggml_sycl_disable_graph) { + bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph); + if (use_sycl_graph) { const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph); if (!graph_support) { GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device); @@ -3874,7 +3957,7 @@ catch (sycl::exception const &exc) } static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try { - + GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); sycl::event* sycl_event = static_cast(event->context); if (ggml_backend_is_sycl(backend)) { @@ -4016,6 +4099,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SGN: @@ -4161,6 +4245,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g #endif case GGML_OP_NORM: case GGML_OP_RMS_NORM: + return true; case GGML_OP_L2_NORM: case GGML_OP_GROUP_NORM: return ggml_is_contiguous(op->src[0]); @@ -4172,14 +4257,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SOFT_MAX: return true; case GGML_OP_ROPE: - { - const int mode = ((const int32_t *) op->op_params)[2]; - // mode is not used as a bitmask in practice, the various rope type modes are independent implementations - if (mode == GGML_ROPE_TYPE_MROPE) { - return false; - } - return true; - } case GGML_OP_IM2COL: return true; case GGML_OP_UPSCALE: @@ -4269,6 +4346,7 @@ static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_bac static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try { GGML_UNUSED(dev); + GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); sycl::event *sycl_event = static_cast(event->context); SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait())); diff --git a/ggml/src/ggml-sycl/gla.cpp b/ggml/src/ggml-sycl/gla.cpp index eedb47486..879184fdd 100644 --- a/ggml/src/ggml-sycl/gla.cpp +++ b/ggml/src/ggml-sycl/gla.cpp @@ -76,6 +76,7 @@ static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B, } void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/5); const float * k_d = static_cast(dst->src[0]->data); const float * v_d = static_cast(dst->src[1]->data); const float * r_d = static_cast(dst->src[2]->data); diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 23eeb74da..cb70f83a4 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1059,8 +1059,10 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n"); reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, 
dst_dd_i_bs, ne00, row_diff, stream); } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl\n"); mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } break; diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 4e9f438b4..4ec141684 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -1,40 +1,50 @@ #include "norm.hpp" +#include "ggml-sycl/common.hpp" +#include "ggml-sycl/presets.hpp" -static void norm_f32(const float* x, float* dst, const int ncols, const float eps, - const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) { - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); - const int tid = item_ct1.get_local_id(2); +static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) { + + const int nrows = item_ct1.get_group_range(2); + const int nchannels = item_ct1.get_group_range(1); const int nthreads = item_ct1.get_local_range(2); + const int sample = item_ct1.get_group(0); + const int channel = item_ct1.get_group(1); + const int row = item_ct1.get_group(2); + + const int tid = item_ct1.get_local_id(2); const int nwarps = nthreads / WARP_SIZE; + + const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row}); + const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row}); + + x += strided_offset; + dst += packed_offset; + sycl::float2 mean_var = sycl::float2(0.f, 0.f); for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row * ncols + col]; + const float xi = x[col]; mean_var.x() += xi; mean_var.y() += xi * xi; } // sum up partial sums mean_var = warp_reduce_sum(mean_var, item_ct1); - if (block_size > WARP_SIZE) { - - int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; - int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = mean_var; + if (block_size > WARP_SIZE) { + const auto sub_group = item_ct1.get_sub_group(); + const auto sg_id = sub_group.get_group_linear_id(); + const auto wi_in_sg = sub_group.get_local_linear_id(); + if (wi_in_sg == 0) { + s_sum[sg_id] = mean_var; } - /* - DPCT1118:0: SYCL group functions and algorithms must be encountered in - converged control flow. You may need to adjust the code. 
- */ item_ct1.barrier(sycl::access::fence_space::local_space); mean_var = 0.f; - size_t nreduce = nwarps / WARP_SIZE; + const size_t nreduce = ceil_div(nwarps, WARP_SIZE); for (size_t i = 0; i < nreduce; i += 1) { - mean_var += s_sum[lane_id + i * WARP_SIZE]; + mean_var += s_sum[wi_in_sg + i * WARP_SIZE]; } mean_var = warp_reduce_sum(mean_var, item_ct1); } @@ -44,7 +54,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep const float inv_std = sycl::rsqrt(var + eps); for (int col = tid; col < ncols; col += block_size) { - dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std; + dst[col] = (x[col] - mean) * inv_std; } } @@ -135,39 +145,51 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con } } -static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps, - const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); - const int tid = item_ct1.get_local_id(2); +static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { + + const int nrows = item_ct1.get_group_range(2); + const int nchannels = item_ct1.get_group_range(1); + + const int sample = item_ct1.get_group(0); + const int channel = item_ct1.get_group(1); + const int row = item_ct1.get_group(2); + const int nthreads = item_ct1.get_local_range(2); + + const int tid = item_ct1.get_local_id(2); const int nwarps = nthreads / WARP_SIZE; + + const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row}); + const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row}); + + x += strided_offset; + dst += packed_offset; + + float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row * ncols + col]; + const float xi = x[col]; tmp += xi * xi; } // sum up partial sums tmp = warp_reduce_sum(tmp, item_ct1); if (block_size > WARP_SIZE) { - - int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; - int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; + const auto sub_group = item_ct1.get_sub_group(); + const auto sg_id = sub_group.get_group_linear_id(); + const auto wi_in_sg = sub_group.get_local_linear_id(); + if (wi_in_sg == 0) { + s_sum[sg_id] = tmp; } - /* - DPCT1118:3: SYCL group functions and algorithms must be encountered in - converged control flow. You may need to adjust the code. 
- */ + item_ct1.barrier(sycl::access::fence_space::local_space); - size_t nreduce = nwarps / WARP_SIZE; + const size_t nreduce = ceil_div(nwarps, WARP_SIZE); tmp = 0.f; for (size_t i = 0; i < nreduce; i += 1) { - tmp += s_sum[lane_id + i * WARP_SIZE]; + tmp += s_sum[wi_in_sg + i * WARP_SIZE]; } tmp = warp_reduce_sum(tmp, item_ct1); } @@ -176,7 +198,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa const float scale = sycl::rsqrt(mean + eps); for (int col = tid; col < ncols; col += block_size) { - dst[row * ncols + col] = scale * x[row * ncols + col]; + dst[col] = scale * x[col]; } } @@ -224,20 +246,20 @@ static void l2_norm_f32(const float* x, float* dst, const int ncols, const float } } -static void norm_f32_sycl(const float* x, float* dst, const int ncols, - const int nrows, const float eps, - queue_ptr stream, int device) { +static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, + const float eps, queue_ptr stream, int device) { + + const sycl::range<3> global_dims(nsamples, nchannels, nrows); GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); stream->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, - block_dims), + sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - norm_f32(x, dst, ncols, eps, item_ct1, - nullptr, WARP_SIZE); + norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE); }); }); } @@ -252,15 +274,12 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols, */ stream->submit([&](sycl::handler& cgh) { sycl::local_accessor s_sum_acc_ct1( - sycl::range<1>(work_group_size / WARP_SIZE), cgh); - + sycl::range<1>(work_group_size / WARP_SIZE), cgh); cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, - block_dims), + sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - norm_f32(x, dst, ncols, eps, item_ct1, - get_pointer(s_sum_acc_ct1), work_group_size); + norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); }); }); } @@ -313,21 +332,20 @@ static void group_norm_f32_sycl(const float* x, float* dst, } } -static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, - const int nrows, const float eps, - queue_ptr stream, int device) { +static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) { GGML_ASSERT(ncols % WARP_SIZE == 0); // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); + + const sycl::range<3> global_dims(nsamples, nchannels, nrows); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); stream->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, - block_dims), + sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - rms_norm_f32(x, 
dst, ncols, eps, item_ct1, - nullptr, WARP_SIZE); + rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE); }); }); } @@ -344,12 +362,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), cgh); cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, - block_dims), + sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - rms_norm_f32(x, dst, ncols, eps, item_ct1, - get_pointer(s_sum_acc_ct1), work_group_size); + rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); }); }); } @@ -398,12 +414,12 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, } void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + const ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - const int64_t ne00 = dst->src[0]->ne[0]; - const int64_t nrows = ggml_nrows(dst->src[0]); + GGML_TENSOR_UNARY_OP_LOCALS dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); const float * src0_dd = static_cast(dst->src[0]->data); @@ -411,8 +427,14 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; - norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); + norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device); } void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { @@ -436,11 +458,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - const int64_t ne00 = dst->src[0]->ne[0]; - const int64_t nrows = ggml_nrows(dst->src[0]); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); @@ -450,7 +471,13 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); - rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); + GGML_TENSOR_UNARY_OP_LOCALS + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device); } void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp index b60415784..3a17f3a1b 100644 --- a/ggml/src/ggml-sycl/outprod.cpp +++ b/ggml/src/ggml-sycl/outprod.cpp @@ -1,6 +1,7 @@ #include "outprod.hpp" void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); const ggml_tensor *src0 = 
dst->src[0]; const ggml_tensor *src1 = dst->src[1]; diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index 4e276d3b6..44473e1e5 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -49,10 +49,7 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const if (i0 >= n_dims) { const int i = row * ne0 + i0; - - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; - + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); return; } @@ -93,10 +90,7 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const if (i0 >= n_dims) { const int i = row * ne0 + i0; - - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; - + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); return; } @@ -122,6 +116,63 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta; } +template +static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, + const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale, + const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float * freq_factors, const mrope_sections sections, + const sycl::nd_item<3> & item_ct1) { + // get index pos + const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1)); + if (i0 >= ne0) { + return; + } + const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); + + if (i0 >= n_dims) { + const int i = row_dst*ne0 + i0; + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); + return; + } + + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + const int idst = (row_dst * ne0) + (i0 / 2); + const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; + + + float theta_base = 0.0; + if (sector < sections.v[0]) { + theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f); + } + else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f); + } + else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f); + } + + const float freq_factor = has_ff ? 
freq_factors[i0 / 2] : 1.0f; + float cos_theta; + float sin_theta; + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims/2]; + + // store results in dst + dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; + dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta; +} + + + template static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale, @@ -171,7 +222,7 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c const float * freq_factors, queue_ptr stream) { GGML_ASSERT(ne0 % 2 == 0); const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); + const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); const sycl::range<3> block_nums(1, num_blocks_x, nr); const float theta_scale = powf(freq_base, -2.0f / n_dims); @@ -208,7 +259,7 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { GGML_ASSERT(ne0 % 2 == 0); const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); + const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); const sycl::range<3> block_nums(1, num_blocks_x, nr); const float theta_scale = powf(freq_base, -2.0f / n_dims); @@ -228,6 +279,40 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c } } +template +static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, + const size_t s2, const int n_dims, const int nr, const int32_t * pos, + const float freq_scale, const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors, + const mrope_sections sections, queue_ptr stream) { + GGML_ASSERT(ne0 % 2 == 0); + const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); + const sycl::range<3> grid_dims(1, n_blocks_y, nr); + const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims); + + const float theta_scale = std::pow(freq_base, -2.0f / n_dims); + // Add FP16 capability check if T could be sycl::half + if constexpr (std::is_same_v) { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + } + // launch kernel + if (freq_factors == nullptr) { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { + rope_multi(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, item_ct1); + }); + } else { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { + rope_multi(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, item_ct1); + }); + } +} + + + + // rope vision template static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, @@ -237,7 +322,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const mrope_sections sections, queue_ptr stream) { GGML_ASSERT(ne0 % 2 == 0); const 
sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int n_blocks_y = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); + const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); const sycl::range<3> grid_dims(1, n_blocks_y, nr); const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims); @@ -298,8 +383,17 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4); const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + if (is_mrope) { + GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne00/2); + } + const int32_t * pos = (const int32_t *) dst->src[1]->data; const float * freq_factors = nullptr; @@ -326,6 +420,19 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) } else { GGML_ABORT("fatal error"); } + } else if (is_mrope && !is_vision) { + GGML_SYCL_DEBUG("%s: mrope path\n", __func__); + if (dst->src[0]->type == GGML_TYPE_F16) { + rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01, + s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, + freq_factors, sections, main_stream); + } else if (dst->src[0]->type == GGML_TYPE_F32) { + rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims, + nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, + main_stream); + } else { + GGML_ABORT("Fatal error: Tensor type unsupported!"); + } } else if (is_vision) { GGML_SYCL_DEBUG("%s: vision path\n", __func__); if (dst->src[0]->type == GGML_TYPE_F16) { @@ -355,8 +462,7 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) } void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); ggml_sycl_op_rope(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index 7563d9ced..52fcf4b3d 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -225,7 +225,7 @@ static void soft_max_f32_sycl(const float * x, const T * mask, } void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -249,16 +249,13 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) { const sycl::half * src1_dd = static_cast(dst->src[1]->data); - GGML_SYCL_DEBUG("%s: F16 mask\n", __func__); soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) { const float * src1_dd = static_cast(dst->src[1]->data); - GGML_SYCL_DEBUG("%s: F32 mask\n", __func__); soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); } else { /* mask unavailable */ - GGML_SYCL_DEBUG("%s: No mask\n", __func__); soft_max_f32_sycl(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, 
scale, max_bias, main_stream, ctx.device); } } diff --git a/ggml/src/ggml-sycl/tsembd.cpp b/ggml/src/ggml-sycl/tsembd.cpp index b877d18c1..f6ca626ea 100644 --- a/ggml/src/ggml-sycl/tsembd.cpp +++ b/ggml/src/ggml-sycl/tsembd.cpp @@ -56,8 +56,8 @@ static void timestep_embedding_f32_sycl( } void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - const ggml_tensor *src0 = dst->src[0]; - const ggml_tensor *src1 = dst->src[1]; + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; float * dst_d = (float *)dst->data; dpct::queue_ptr stream = ctx.stream(); @@ -69,5 +69,4 @@ void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tenso const int max_period = dst->op_params[1]; timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream); - GGML_UNUSED(src1); } diff --git a/ggml/src/ggml-sycl/wkv.cpp b/ggml/src/ggml-sycl/wkv.cpp index 540f6fbf5..c10e2f764 100644 --- a/ggml/src/ggml-sycl/wkv.cpp +++ b/ggml/src/ggml-sycl/wkv.cpp @@ -180,10 +180,7 @@ static void rwkv_wkv7_f32_kernel( } void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { - - const ggml_tensor *src0 = dst->src[0]; - const ggml_tensor *src1 = dst->src[1]; - + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6); const float* k_d = (const float*)dst->src[0]->data; const float* v_d = (const float*)dst->src[1]->data; const float* r_d = (const float*)dst->src[2]->data; @@ -236,16 +233,10 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { }); }); } - - GGML_UNUSED(src0); - GGML_UNUSED(src1); } void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { - - const ggml_tensor *src0 = dst->src[0]; - const ggml_tensor *src1 = dst->src[1]; - + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/7); const float* r_d = (const float*)dst->src[0]->data; const float* w_d = (const float*)dst->src[1]->data; const float* k_d = (const float*)dst->src[2]->data; @@ -299,7 +290,4 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { }); }); } - - GGML_UNUSED(src0); - GGML_UNUSED(src1); } diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 662f13771..4a88415f9 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -109,10 +109,6 @@ if (Vulkan_FOUND) add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) endif() - if (GGML_VULKAN_PERF) - add_compile_definitions(GGML_VULKAN_PERF) - endif() - if (GGML_VULKAN_VALIDATE) add_compile_definitions(GGML_VULKAN_VALIDATE) endif() diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c160a9984..41d20aa5d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1,6 +1,6 @@ #include "ggml-vulkan.h" #include -#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS) +#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS) #include #include "ggml-cpu.h" #endif @@ -184,9 +184,7 @@ static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { #ifdef GGML_VULKAN_MEMORY_DEBUG class vk_memory_logger; #endif -#ifdef GGML_VULKAN_PERF class vk_perf_logger; -#endif static void ggml_vk_destroy_buffer(vk_buffer& buf); static constexpr uint32_t mul_mat_vec_max_cols = 8; @@ -442,9 
+440,11 @@ struct vk_device_struct { #ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; #endif -#ifdef GGML_VULKAN_PERF + + // for GGML_VK_PERF_LOGGER std::unique_ptr perf_logger; -#endif + vk::QueryPool query_pool; + uint32_t num_queries; ~vk_device_struct() { VK_LOG_DEBUG("destroy device " << name); @@ -828,8 +828,6 @@ private: #define VK_LOG_MEMORY(msg) ((void) 0) #endif // GGML_VULKAN_MEMORY_DEBUG -#if defined(GGML_VULKAN_PERF) - class vk_perf_logger { public: void print_timings() { @@ -839,7 +837,7 @@ public: for (const auto& time : t.second) { total += time; } - std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " ms" << std::endl; + std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << std::endl; } timings.clear(); @@ -868,7 +866,6 @@ public: private: std::map> timings; }; -#endif // GGML_VULKAN_PERF struct ggml_backend_vk_context { std::string name; @@ -958,6 +955,8 @@ struct vk_instance_t { static bool vk_instance_initialized = false; static vk_instance_t vk_instance; +static bool vk_perf_logger_enabled = false; + #ifdef GGML_VULKAN_CHECK_RESULTS static size_t vk_skip_checks; static size_t vk_output_tensor; @@ -1653,7 +1652,7 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t D, uint32_ return {64, 32}; } return {64, 64}; -}; +} static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -2757,9 +2756,9 @@ static vk_device ggml_vk_get_device(size_t idx) { #ifdef GGML_VULKAN_MEMORY_DEBUG device->memory_logger = std::unique_ptr(new vk_memory_logger()); #endif -#ifdef GGML_VULKAN_PERF - device->perf_logger = std::unique_ptr(new vk_perf_logger()); -#endif + if (vk_perf_logger_enabled) { + device->perf_logger = std::unique_ptr(new vk_perf_logger()); + } size_t dev_num = vk_instance.device_indices[idx]; @@ -2804,23 +2803,29 @@ static vk_device ggml_vk_get_device(size_t idx) { pipeline_robustness = true; } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { device->subgroup_size_control = true; +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT")) { device->coopmat_support = true; device->coopmat_m = 0; device->coopmat_n = 0; device->coopmat_k = 0; +#endif +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; +#endif #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { device->integer_dot_product = true; #endif +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; +#endif } } @@ -3541,6 +3546,8 @@ static void ggml_vk_instance_init() { vk_instance.instance = vk::createInstance(instance_create_info); vk_instance_initialized = true; + vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; + size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan @@ -4670,6 +4677,19 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const } } + 
if (src->type == to) { + // Copy two or four bytes at a time, depending on block size. + // For quantized types, we scale by block size/type size. But + // this path is also used for bf16->bf16 for example, where the + // type size must be exactly 2 or 4. + GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4); + if ((ggml_type_size(src->type) % 4) == 0) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_contig_cpy_f16_f16; + } + } + std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; GGML_ABORT("fatal error"); } @@ -6433,6 +6453,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_ROPE: case GGML_OP_RMS_NORM: case GGML_OP_CONV_2D_DW: + case GGML_OP_IM2COL: return true; default: return false; @@ -6731,7 +6752,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_UNARY: case GGML_OP_CONV_2D_DW: { - const uint32_t ne = ggml_nelements(dst); + uint32_t ne = ggml_nelements(dst); + if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. + ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } if (ne > 262144) { elements = { 512, 512, CEIL_DIV(ne, 262144) }; } else if (ne > 512) { @@ -7281,8 +7311,19 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); + uint32_t ne = (uint32_t)ggml_nelements(src0); + if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. + ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, { - (uint32_t)ggml_nelements(src0), + ne, (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, @@ -8845,7 +8886,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod ctx->tensor_ctxs[node_idx] = compute_ctx; -#if defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_PERF) +#if defined(GGML_VULKAN_CHECK_RESULTS) // Force context reset on each node so that each tensor ends up in its own context // and can be run and compared to its CPU equivalent separately last_node = true; @@ -9264,8 +9305,7 @@ static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_ try { ptr = ggml_vk_host_malloc(vk_instance.devices[0], size); } catch (vk::SystemError& e) { - std::cerr << "ggml_vulkan: Failed to allocate pinned memory." 
<< std::endl; - std::cerr << "ggml_vulkan: " << e.what() << std::endl; + GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what()); // fallback to cpu buffer return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); } @@ -9466,6 +9506,29 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool first_node_in_batch = true; // true if next node will be first node in a batch int submit_node_idx = 0; // index to first node in a batch + vk_context compute_ctx; + if (vk_perf_logger_enabled) { + // allocate/resize the query pool + if (ctx->device->num_queries < cgraph->n_nodes + 1) { + if (ctx->device->query_pool) { + ctx->device->device.destroyQueryPool(ctx->device->query_pool); + } + VkQueryPoolCreateInfo query_create_info = { VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO }; + query_create_info.queryType = VK_QUERY_TYPE_TIMESTAMP; + query_create_info.queryCount = cgraph->n_nodes + 100; + ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info); + ctx->device->num_queries = query_create_info.queryCount; + } + + ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1); + + GGML_ASSERT(ctx->compute_ctx.expired()); + compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0); + } + // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB // (and scaled down based on model size, so smaller models submit earlier). 
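
The block above replaces the old compile-time GGML_VULKAN_PERF counters with a runtime GGML_VK_PERF_LOGGER path: a timestamp query is written before the first node and after every node, the pool is read back once the fence is signalled, and each per-node tick delta is scaled by the device's timestampPeriod. As an aside, here is a minimal sketch of that timestamp-query pattern (not the patch code; `device`, `cmd` and `props` are assumed to be valid vk::Device, vk::CommandBuffer and vk::PhysicalDeviceProperties handles, and pool reuse and error handling are omitted):

    // write one timestamp before and one after the work to be measured
    vk::QueryPoolCreateInfo info({}, vk::QueryType::eTimestamp, /*queryCount=*/2);
    vk::QueryPool pool = device.createQueryPool(info);
    device.resetQueryPool(pool, 0, 2);

    cmd.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, pool, 0);
    // ... record the dispatches to be timed ...
    cmd.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, pool, 1);

    // after submitting `cmd` and waiting on its fence:
    uint64_t ts[2];
    (void) device.getQueryPoolResults(pool, 0, 2, sizeof(ts), ts, sizeof(uint64_t),
                                      vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);
    // timestamps are in ticks; timestampPeriod converts ticks to nanoseconds
    double ns = double(ts[1] - ts[0]) * props.limits.timestampPeriod;

The values handed to vk_perf_logger::log_timing are exactly such tick deltas scaled to nanoseconds; print_timings then divides by 1000 before printing, which is why its unit string changes from "ms" to "us".
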
@@ -9493,6 +9556,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit); + if (vk_perf_logger_enabled) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1); + } + if (enqueued) { ++submitted_nodes; @@ -9514,9 +9588,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg } } -#ifdef GGML_VULKAN_PERF - ctx->device->perf_logger->print_timings(); -#endif + if (vk_perf_logger_enabled) { + // End the command buffer and submit/wait + GGML_ASSERT(!ctx->compute_ctx.expired()); + compute_ctx = ctx->compute_ctx.lock(); + ggml_vk_ctx_end(compute_ctx); + + ggml_vk_submit(compute_ctx, ctx->device->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences"); + ctx->device->device.resetFences({ ctx->device->fence }); + + // Get the results and pass them to the logger + std::vector timestamps(cgraph->n_nodes + 1); + ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait); + for (int i = 0; i < cgraph->n_nodes; i++) { + if (!ggml_vk_is_empty(cgraph->nodes[i])) { + ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod)); + } + } + + ctx->device->perf_logger->print_timings(); + } ggml_vk_graph_cleanup(ctx); @@ -9867,6 +9959,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { return true; } + + // We can handle copying from a type to the same type if it's + // contiguous (memcpy). We use f16 or f32 shaders to do the copy, + // so the type/block size must be a multiple of 4. 
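
A worked example of the element-count conversion used by this same-type copy path may help (illustrative numbers only, taken from the standard ggml block layouts; the shader is chosen in ggml_vk_get_cpy_pipeline above, a 4-byte copy when the type size is a multiple of 4 and a 2-byte copy otherwise):

    // Contiguous Q4_0 -> Q4_0 copy of 4096 logical elements (illustrative sketch).
    // Q4_0: block size 32, type size 18 bytes; 18 % 4 != 0, so the f16 (2-byte)
    // copy shader is used and the element count is expressed in 2-byte units.
    uint32_t ne = 4096;
    ne /= 32;        // 4096 elements -> 128 blocks
    ne *= 18 / 2;    //  128 blocks   -> 1152 two-byte units for the copy shader

Non-quantized same-type copies (for example bf16 to bf16) skip this conversion and keep the logical element count, which is why the assert in ggml_vk_get_cpy_pipeline only requires the type size to be exactly 2 or 4 in that case.
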
+ if (src0_type == src1_type && + ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) && + (ggml_type_size(src0_type) % 2) == 0) { + return true; + } return false; } break; case GGML_OP_REPEAT: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 57d3e39ad..196b7b8f3 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -133,7 +133,7 @@ static void ggml_print_backtrace_symbols(void) { } #endif -static void ggml_print_backtrace(void) { +void ggml_print_backtrace(void) { const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE"); if (GGML_NO_BACKTRACE) { return; @@ -160,6 +160,10 @@ static void ggml_print_backtrace(void) { const int parent_pid = getpid(); const int child_pid = fork(); if (child_pid < 0) { // error +#if defined(__linux__) + close(lock[1]); + close(lock[0]); +#endif return; } else if (child_pid == 0) { // child char attach[32]; @@ -167,6 +171,7 @@ static void ggml_print_backtrace(void) { #if defined(__linux__) close(lock[1]); (void) !read(lock[0], lock, 1); + close(lock[0]); #endif // try gdb execlp("gdb", "gdb", "--batch", @@ -195,7 +200,7 @@ static void ggml_print_backtrace(void) { } } #else -static void ggml_print_backtrace(void) { +void ggml_print_backtrace(void) { // platform not supported } #endif @@ -216,6 +221,8 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) { abort(); } +// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp + // // logging // @@ -2312,6 +2319,26 @@ struct ggml_tensor * ggml_repeat( return result; } +struct ggml_tensor * ggml_repeat_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) { + const bool can_repeat = ggml_is_empty(a) || ( + (ne0 % a->ne[0] == 0) && + (ne1 % a->ne[1] == 0) && + (ne2 % a->ne[2] == 0) && + (ne3 % a->ne[3] == 0) + ); + GGML_ASSERT(can_repeat); + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + + result->op = GGML_OP_REPEAT; + result->src[0] = a; + + return result; +} + // ggml_repeat_back struct ggml_tensor * ggml_repeat_back( diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp new file mode 100644 index 000000000..0d388d455 --- /dev/null +++ b/ggml/src/ggml.cpp @@ -0,0 +1,26 @@ +#include "ggml-impl.h" + +#include +#include + +static std::terminate_handler previous_terminate_handler; + +GGML_NORETURN static void ggml_uncaught_exception() { + ggml_print_backtrace(); + if (previous_terminate_handler) { + previous_terminate_handler(); + } + abort(); // unreachable unless previous_terminate_handler was nullptr +} + +static bool ggml_uncaught_exception_init = []{ + const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE"); + if (GGML_NO_BACKTRACE) { + return false; + } + const auto prev{std::get_terminate()}; + GGML_ASSERT(prev != ggml_uncaught_exception); + previous_terminate_handler = prev; + std::set_terminate(ggml_uncaught_exception); + return true; +}(); diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 8667a80bd..a0a318a29 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -347,11 +347,28 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par int64_t n_tensors = 0; if (ok && gr.read(ctx->version)) { - if (ctx->version == 1) { + if (ok && ctx->version == 0) { + GGML_LOG_ERROR("%s: bad GGUF version: %" PRIu32 "\n", __func__, ctx->version); + ok = false; + } + + /* + * bit layout is different when reading non-native endian models. + * assuming that the GGUF version is 3, the non-native endian model + * would read it as 0x30000000. 
we can use the AND operation against + * the last 4 hexadecimal digits to check if the model is the same + * endianness as the host system. + */ + if (ok && (ctx->version & 0x0000FFFF) == 0x00000000) { + GGML_LOG_ERROR("%s: failed to load model: this GGUF file version %" PRIu32 " is extremely large, is there a mismatch between the host and model endianness?\n", __func__, ctx->version); + ok = false; + } + + if (ok && ctx->version == 1) { GGML_LOG_ERROR("%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__); ok = false; } - if (ctx->version > GGUF_VERSION) { + if (ok && ctx->version > GGUF_VERSION) { GGML_LOG_ERROR("%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n", __func__, ctx->version, GGUF_VERSION); ok = false; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 618f87180..ae30129b3 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -177,6 +177,9 @@ class Keys: EMBEDDING_LENGTH = "{arch}.convnext.embedding_length" BLOCK_COUNT = "{arch}.convnext.block_count" + class Classifier: + OUTPUT_LABELS = "{arch}.classifier.output_labels" + class Tokenizer: MODEL = "tokenizer.ggml.model" PRE = "tokenizer.ggml.pre" @@ -219,10 +222,13 @@ class Keys: TYPE = "adapter.type" LORA_ALPHA = "adapter.lora.alpha" - class ClipVision: + class Clip: PROJECTOR_TYPE = "clip.projector_type" HAS_VISION_ENCODER = "clip.has_vision_encoder" + HAS_AUDIO_ENCODER = "clip.has_audio_encoder" HAS_LLAVA_PROJECTOR = "clip.has_llava_projector" + + class ClipVision: IMAGE_SIZE = "clip.vision.image_size" PATCH_SIZE = "clip.vision.patch_size" EMBEDDING_LENGTH = "clip.vision.embedding_length" @@ -243,19 +249,33 @@ class Keys: class Projector: SCALE_FACTOR = "clip.vision.projector.scale_factor" + class ClipAudio: + NUM_MEL_BINS = "clip.audio.num_mel_bins" + EMBEDDING_LENGTH = "clip.audio.embedding_length" + FEED_FORWARD_LENGTH = "clip.audio.feed_forward_length" + PROJECTION_DIM = "clip.audio.projection_dim" + BLOCK_COUNT = "clip.audio.block_count" + + class Attention: + HEAD_COUNT = "clip.audio.attention.head_count" + LAYERNORM_EPS = "clip.audio.attention.layer_norm_epsilon" + + class Projector: + STACK_FACTOR = "clip.audio.projector.stack_factor" + # # recommended mapping of model tensor names for storage in gguf # class GGUFType: - MODEL = "model" - ADAPTER = "adapter" - CLIP_VISION = "clip-vision" + MODEL = "model" + ADAPTER = "adapter" + MMPROJ = "mmproj" # dummy, unused for now class MODEL_ARCH(IntEnum): - CLIP_VISION = auto() # dummy arch for clip.cpp + MMPROJ = auto() # dummy arch for clip.cpp LLAMA = auto() LLAMA4 = auto() DECI = auto() @@ -515,10 +535,28 @@ class MODEL_TENSOR(IntEnum): V_RESMPL_QUERY = auto() # minicpmv V_TOK_EMBD_IMG_BREAK = auto() # pixtral V_MM_PATCH_MERGER = auto() # mistral small 3.1 + # audio (mtmd) + A_ENC_EMBD_POS = auto() + A_ENC_CONV1D = auto() + A_PRE_NORM = auto() + A_POST_NORM = auto() + A_ENC_ATTN_Q = auto() + A_ENC_ATTN_K = auto() + A_ENC_ATTN_V = auto() + A_ENC_INPUT_NORM = auto() + A_ENC_OUTPUT = auto() + A_ENC_OUTPUT_NORM = auto() + A_ENC_FFN_UP = auto() + A_ENC_FFN_GATE = auto() + A_ENC_FFN_DOWN = auto() + A_MMPROJ = auto() + A_MMPROJ_FC = auto() + A_MM_NORM_PRE = auto() + A_MM_NORM_MID = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { - MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp + MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp MODEL_ARCH.LLAMA: "llama", MODEL_ARCH.LLAMA4: "llama4", MODEL_ARCH.DECI: "deci", @@ -778,10 +816,28 @@ TENSOR_NAMES: 
dict[MODEL_TENSOR, str] = { MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query", MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1 + # audio (mtmd) + MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", + MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", + MODEL_TENSOR.A_PRE_NORM: "a.pre_ln", + MODEL_TENSOR.A_POST_NORM: "a.post_ln", + MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q", + MODEL_TENSOR.A_ENC_ATTN_K: "a.blk.{bid}.attn_k", + MODEL_TENSOR.A_ENC_ATTN_V: "a.blk.{bid}.attn_v", + MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1", + MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out", + MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2", + MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up", + MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate", + MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down", + MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}", + MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc", + MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre", + MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { - MODEL_ARCH.CLIP_VISION: [ + MODEL_ARCH.MMPROJ: [ MODEL_TENSOR.V_MMPROJ, MODEL_TENSOR.V_MMPROJ_FC, MODEL_TENSOR.V_MMPROJ_MLP, @@ -821,6 +877,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.V_RESMPL_QUERY, MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK, MODEL_TENSOR.V_MM_PATCH_MERGER, + # audio + MODEL_TENSOR.A_ENC_EMBD_POS, + MODEL_TENSOR.A_ENC_CONV1D, + MODEL_TENSOR.A_PRE_NORM, + MODEL_TENSOR.A_POST_NORM, + MODEL_TENSOR.A_ENC_ATTN_Q, + MODEL_TENSOR.A_ENC_ATTN_K, + MODEL_TENSOR.A_ENC_ATTN_V, + MODEL_TENSOR.A_ENC_INPUT_NORM, + MODEL_TENSOR.A_ENC_OUTPUT, + MODEL_TENSOR.A_ENC_OUTPUT_NORM, + MODEL_TENSOR.A_ENC_FFN_UP, + MODEL_TENSOR.A_ENC_FFN_GATE, + MODEL_TENSOR.A_ENC_FFN_DOWN, + MODEL_TENSOR.A_MMPROJ, + MODEL_TENSOR.A_MMPROJ_FC, + MODEL_TENSOR.A_MM_NORM_PRE, + MODEL_TENSOR.A_MM_NORM_MID, ], MODEL_ARCH.LLAMA: [ MODEL_TENSOR.TOKEN_EMBD, @@ -964,6 +1038,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.POS_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2200,7 +2275,10 @@ class VisionProjectorType: LLAMA4 = "llama4" QWEN2VL = "qwen2vl_merger" QWEN25VL = "qwen2.5vl_merger" + ULTRAVOX = "ultravox" INTERNVL = "internvl" + QWEN2A = "qwen2a" # audio + QWEN25O = "qwen2.5o" # omni # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index ff50d3de3..de6e45ae8 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -49,6 +49,7 @@ class TensorInfo: class GGUFValue: value: Any type: GGUFValueType + sub_type: GGUFValueType | None = None class WriterState(Enum): @@ -238,7 +239,7 @@ class GGUFWriter: for key, val in kv_data.items(): kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) - kv_bytes += self._pack_val(val.value, val.type, add_vtype=True) + kv_bytes += self._pack_val(val.value, val.type, add_vtype=True, sub_type=val.sub_type) fout.write(kv_bytes) @@ -268,11 +269,11 @@ class GGUFWriter: fout.flush() self.state = WriterState.TI_DATA - def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None: + def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None: if any(key in kv_data for kv_data in self.kv_data): raise ValueError(f'Duplicated key name {key!r}') - self.kv_data[0][key] = 
GGUFValue(value=val, type=vtype) + self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type) def add_uint8(self, key: str, val: int) -> None: self.add_key_value(key,val, GGUFValueType.UINT8) @@ -896,7 +897,7 @@ class GGUFWriter: def add_remove_extra_whitespaces(self, value: bool) -> None: self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value) - def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None: + def add_precompiled_charsmap(self, charsmap: bytes) -> None: self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap) def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: @@ -936,12 +937,18 @@ class GGUFWriter: # for vision models + def add_clip_has_vision_encoder(self, value: bool) -> None: + self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value) + + def add_clip_has_audio_encoder(self, value: bool) -> None: + self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value) + + def add_clip_projector_type(self, value: str) -> None: + self.add_string(Keys.Clip.PROJECTOR_TYPE, value) + def add_vision_projection_dim(self, value: int) -> None: self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value) - def add_vision_has_vision_encoder(self, value: bool) -> None: - self.add_bool(Keys.ClipVision.HAS_VISION_ENCODER, value) - def add_vision_patch_size(self, value: int) -> None: self.add_uint32(Keys.ClipVision.PATCH_SIZE, value) @@ -957,9 +964,6 @@ class GGUFWriter: def add_vision_head_count(self, value: int) -> None: self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value) - def add_vision_projector_type(self, value: str) -> None: - self.add_string(Keys.ClipVision.PROJECTOR_TYPE, value) - def add_vision_attention_layernorm_eps(self, value: float) -> None: self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value) @@ -987,13 +991,39 @@ class GGUFWriter: def add_vision_n_wa_pattern(self, value: int) -> None: self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value) + # audio models + + def add_audio_projection_dim(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value) + + def add_audio_embedding_length(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value) + + def add_audio_feed_forward_length(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value) + + def add_audio_block_count(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value) + + def add_audio_head_count(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value) + + def add_audio_attention_layernorm_eps(self, value: float) -> None: + self.add_float32(Keys.ClipAudio.Attention.LAYERNORM_EPS, value) + + def add_audio_num_mel_bins(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value) + + def add_audio_stack_factor(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value) + def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes: pack_prefix = '' if not skip_pack_prefix: pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>' return struct.pack(f'{pack_prefix}{fmt}', value) - def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes: + def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: GGUFValueType | None = None) -> bytes: kv_data = bytearray() if add_vtype: @@ -1014,7 +1044,9 @@ class GGUFWriter: if len(val) == 0: raise ValueError("Invalid GGUF metadata array. 
Empty array") - if isinstance(val, bytes): + if sub_type is not None: + ltype = sub_type + elif isinstance(val, bytes): ltype = GGUFValueType.UINT8 else: ltype = GGUFValueType.get_type(val[0]) diff --git a/gguf-py/gguf/scripts/gguf_editor_gui.py b/gguf-py/gguf/scripts/gguf_editor_gui.py index 3d38b5cba..05f4db0f8 100755 --- a/gguf-py/gguf/scripts/gguf_editor_gui.py +++ b/gguf-py/gguf/scripts/gguf_editor_gui.py @@ -1521,19 +1521,21 @@ class GGUFEditorWindow(QMainWindow): continue # Apply changes if any + sub_type = None if field.name in self.metadata_changes: value_type, value = self.metadata_changes[field.name] if value_type == GGUFValueType.ARRAY: # Handle array values - element_type, array_values = value - writer.add_array(field.name, array_values) - else: - writer.add_key_value(field.name, value, value_type) + sub_type, value = value else: # Copy original value value = field.contents() - if value is not None and field.types: - writer.add_key_value(field.name, value, field.types[0]) + value_type = field.types[0] + if value_type == GGUFValueType.ARRAY: + sub_type = field.types[-1] + + if value is not None: + writer.add_key_value(field.name, value, value_type, sub_type=sub_type) # Add new metadata for key, (value_type, value) in self.metadata_changes.items(): @@ -1541,7 +1543,12 @@ class GGUFEditorWindow(QMainWindow): if self.reader.get_field(key) is not None: continue - writer.add_key_value(key, value, value_type) + sub_type = None + if value_type == GGUFValueType.ARRAY: + # Handle array values + sub_type, value = value + + writer.add_key_value(key, value, value_type, sub_type=sub_type) # Add tensors (including data) for tensor in self.reader.tensors: diff --git a/gguf-py/gguf/scripts/gguf_new_metadata.py b/gguf-py/gguf/scripts/gguf_new_metadata.py index 7aff6c925..63f230034 100755 --- a/gguf-py/gguf/scripts/gguf_new_metadata.py +++ b/gguf-py/gguf/scripts/gguf_new_metadata.py @@ -24,6 +24,7 @@ class MetadataDetails(NamedTuple): type: gguf.GGUFValueType value: Any description: str = '' + sub_type: gguf.GGUFValueType | None = None def get_field_data(reader: gguf.GGUFReader, key: str) -> Any: @@ -57,7 +58,9 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new logger.debug(f'Removing {field.name}') continue - old_val = MetadataDetails(field.types[0], field.contents()) + val_type = field.types[0] + sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None + old_val = MetadataDetails(val_type, field.contents(), sub_type=sub_type) val = new_metadata.get(field.name, old_val) if field.name in new_metadata: @@ -67,7 +70,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new logger.debug(f'Copying {field.name}') if val.value is not None: - writer.add_key_value(field.name, val.value, val.type) + writer.add_key_value(field.name, val.value, val.type, sub_type=sub_type if val.sub_type is None else val.sub_type) if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata: logger.debug('Adding chat template(s)') diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index b6eb770d8..93dd1d802 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -157,7 +157,7 @@ class TensorNameMap: "h.{bid}.attn.c_attn", # gpt2 "transformer.h.{bid}.mixer.Wqkv", # phi2 "encoder.layers.{bid}.attn.Wqkv", # nomic-bert - "encoder.layers.{bid}.mixer.Wqkv", # jina-bert-v3 + "encoder.layers.{bid}.mixer.Wqkv", # jina "model.layers.{bid}.self_attn.qkv_proj", # phi3 
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm "transformer.layers.{bid}.attn.qkv_proj", # openelm @@ -169,6 +169,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom "layers.{bid}.attention.wq", # llama-pth "encoder.layer.{bid}.attention.self.query", # bert + "transformer.layer.{bid}.attention.q_lin", # distillbert "transformer.h.{bid}.attn.q_proj", # gpt-j "model.layers.layers.{bid}.self_attn.q_proj", # plamo "model.layers.{bid}.attention.wq", # internlm2 @@ -183,6 +184,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom "layers.{bid}.attention.wk", # llama-pth "encoder.layer.{bid}.attention.self.key", # bert + "transformer.layer.{bid}.attention.k_lin", # distillbert "transformer.h.{bid}.attn.k_proj", # gpt-j "transformer.h.{bid}.attn.k", # refact "model.layers.layers.{bid}.self_attn.k_proj", # plamo @@ -197,6 +199,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe "layers.{bid}.attention.wv", # llama-pth "encoder.layer.{bid}.attention.self.value", # bert + "transformer.layer.{bid}.attention.v_lin", # distillbert "transformer.h.{bid}.attn.v_proj", # gpt-j "transformer.h.{bid}.attn.v", # refact "model.layers.layers.{bid}.self_attn.v_proj", # plamo @@ -217,6 +220,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.linear_attn", # deci "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert + "transformer.layer.{bid}.attention.out_lin", # distillbert "transformer.h.{bid}.attn.out_proj", # gpt-j "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon "model.layers.{bid}.self_attn.dense", # persimmon @@ -225,7 +229,7 @@ class TensorNameMap: "model.layers.layers.{bid}.self_attn.o_proj", # plamo "model.layers.{bid}.attention.wo", # internlm2 "encoder.layers.{bid}.attn.out_proj", # nomic-bert - "encoder.layers.{bid}.mixer.out_proj", # jina-bert-v3 + "encoder.layers.{bid}.mixer.out_proj", # jina "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx "encoder.layers.{bid}.self_attention.dense", # chatglm @@ -237,6 +241,7 @@ class TensorNameMap: # Attention output norm MODEL_TENSOR.ATTN_OUT_NORM: ( "encoder.layer.{bid}.attention.output.LayerNorm", # bert + "transformer.layer.{bid}.sa_layer_norm", # distillbert "encoder.layers.{bid}.norm1", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_1", # Grok "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx @@ -313,6 +318,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2 "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert + "transformer.layer.{bid}.ffn.lin1", # distillbert "transformer.h.{bid}.mlp.fc_in", # gpt-j "transformer.h.{bid}.mlp.linear_3", # refact "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon @@ -396,6 +402,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2 "layers.{bid}.feed_forward.w2", # llama-pth "encoder.layer.{bid}.output.dense", # bert + "transformer.layer.{bid}.ffn.lin2", # distillbert "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon "model.layers.{bid}.mlp.dense_4h_to_h", # persimmon @@ -457,6 +464,7 @@ class TensorNameMap: MODEL_TENSOR.LAYER_OUT_NORM: ( "encoder.layer.{bid}.output.LayerNorm", # bert + "transformer.layer.{bid}.output_layer_norm", # 
distillbert "encoder.layers.{bid}.norm2", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_3", # Grok "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 @@ -827,6 +835,7 @@ class TensorNameMap: MODEL_TENSOR.CLS: ( "classifier", # jina "classifier.dense", # roberta + "pre_classifier", # distillbert ), MODEL_TENSOR.CLS_OUT: ( @@ -904,7 +913,6 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ_FC: ( "model.connector.modality_projection.proj", # SmolVLM - "multi_modal_projector.linear_1", # llama 4 ), MODEL_TENSOR.V_MMPROJ_MLP: ( @@ -1112,6 +1120,77 @@ class TensorNameMap: MODEL_TENSOR.V_MM_PATCH_MERGER: ( "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1 ), + + # audio (mtmd) + + MODEL_TENSOR.A_ENC_EMBD_POS: ( + "audio_tower.embed_positions", # ultravox + ), + + MODEL_TENSOR.A_ENC_CONV1D: ( + "audio_tower.conv{bid}", # ultravox + ), + + MODEL_TENSOR.A_PRE_NORM: (), + + MODEL_TENSOR.A_POST_NORM: ( + "audio_tower.layer_norm", # ultravox + "audio_tower.ln_post", # qwen2omni + ), + + MODEL_TENSOR.A_ENC_ATTN_Q: ( + "audio_tower.layers.{bid}.self_attn.q_proj", # ultravox + ), + + MODEL_TENSOR.A_ENC_ATTN_K: ( + "audio_tower.layers.{bid}.self_attn.k_proj", # ultravox + ), + + MODEL_TENSOR.A_ENC_ATTN_V: ( + "audio_tower.layers.{bid}.self_attn.v_proj", # ultravox + ), + + MODEL_TENSOR.A_ENC_INPUT_NORM: ( + "audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox + ), + + MODEL_TENSOR.A_ENC_OUTPUT: ( + "audio_tower.layers.{bid}.self_attn.out_proj", # ultravox + ), + + MODEL_TENSOR.A_ENC_OUTPUT_NORM: ( + "audio_tower.layers.{bid}.final_layer_norm", # ultravox + ), + + MODEL_TENSOR.A_ENC_FFN_UP: ( + "audio_tower.layers.{bid}.fc1", # ultravox + ), + + MODEL_TENSOR.A_ENC_FFN_GATE: (), + + MODEL_TENSOR.A_ENC_FFN_DOWN: ( + "audio_tower.layers.{bid}.fc2", # ultravox + ), + + # note: some tensors below has "audio." 
pseudo-prefix, to prevent conflicts with vision tensors + # this prefix is added in the conversion code in modify_tensors() + + MODEL_TENSOR.A_MMPROJ: ( + "audio.multi_modal_projector.linear_{bid}", # ultravox + ), + + MODEL_TENSOR.A_MMPROJ_FC: ( + "audio.multi_modal_projector.linear", # qwen2audio + "audio_tower.proj", # qwen2omni + ), + + MODEL_TENSOR.A_MM_NORM_PRE: ( + "audio.multi_modal_projector.ln_pre", # ultravox + ), + + MODEL_TENSOR.A_MM_NORM_MID: ( + "audio.multi_modal_projector.ln_mid", # ultravox + ), } # architecture-specific block mappings diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index e5251aef8..00adcbc93 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -231,7 +231,7 @@ class SafetensorRemote: response.raise_for_status() # Get raw byte data - return response.content[:size] + return response.content[slice(size if size > -1 else None)] @classmethod def check_file_exist(cls, url: str) -> bool: diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index bb9b86ace..f11351cba 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gguf" -version = "0.16.3" +version = "0.17.0" description = "Read and write ML models in GGUF for GGML" authors = ["GGML "] packages = [ diff --git a/include/llama.h b/include/llama.h index 52cd7a5a0..da0f652cf 100644 --- a/include/llama.h +++ b/include/llama.h @@ -259,9 +259,9 @@ extern "C" { llama_token * token; float * embd; llama_pos * pos; - int32_t * n_seq_id; - llama_seq_id ** seq_id; - int8_t * logits; // TODO: rename this to "output" + int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence + llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id; + int8_t * logits; // TODO: rename this to "output" } llama_batch; enum llama_model_kv_override_type { @@ -366,6 +366,8 @@ extern "C" { bool no_perf; // measure performance timings bool op_offload; // offload host tensor operations to device bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) + // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases + // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573 }; // model quantization parameters @@ -471,6 +473,7 @@ extern "C" { LLAMA_API int64_t llama_time_us(void); LLAMA_API size_t llama_max_devices(void); + LLAMA_API size_t llama_max_parallel_sequences(void); LLAMA_API bool llama_supports_mmap (void); LLAMA_API bool llama_supports_mlock (void); @@ -501,6 +504,7 @@ extern "C" { LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); // Get the model's RoPE frequency scaling factor LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); @@ -611,11 +615,11 @@ extern "C" { // Returns the number of tokens in the KV cache (slow, use only for debug) // If a KV cell has multiple sequences assigned to it, it will be counted multiple times DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx), - "Use llama_kv_self_seq_pos_max() instead"); + "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); // Returns the number of used KV cells (i.e. 
have at least one sequence assigned to them) DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx), - "Use llama_kv_self_seq_pos_max() instead"); + "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); // Clear the KV cache - both cell info is erased and KV data is zeroed LLAMA_API void llama_kv_self_clear( @@ -651,7 +655,6 @@ extern "C" { // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: // - lazily on next llama_decode() - // - explicitly with llama_kv_self_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) LLAMA_API void llama_kv_self_seq_add( @@ -664,7 +667,6 @@ extern "C" { // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: // - lazily on next llama_decode() - // - explicitly with llama_kv_self_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) LLAMA_API void llama_kv_self_seq_div( @@ -676,12 +678,14 @@ extern "C" { // Returns the smallest position present in the KV cache for the specified sequence // This is typically non-zero only for SWA caches + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty LLAMA_API llama_pos llama_kv_self_seq_pos_min( struct llama_context * ctx, llama_seq_id seq_id); // Returns the largest position present in the KV cache for the specified sequence + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty LLAMA_API llama_pos llama_kv_self_seq_pos_max( struct llama_context * ctx, @@ -690,14 +694,15 @@ extern "C" { // Defragment the KV cache // This will be applied: // - lazily on next llama_decode() - // - explicitly with llama_kv_self_update() - LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx); + LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx), + "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); // Check if the context supports KV cache shifting LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx); // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) 
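
Together with the llama_kv_self_update deprecation just below, the llama_kv_self_defrag deprecation above means callers no longer drive cache maintenance explicitly. A small before/after sketch of the intended migration (caller-side code, assumed here for illustration; `ctx` and `batch` are an existing context and batch):

    // before: explicit cache maintenance around decoding
    llama_kv_self_defrag(ctx);   // now deprecated: remove
    llama_kv_self_update(ctx);   // now deprecated: remove
    llama_decode(ctx, batch);

    // after: just decode; K-shifts and defragmentation are applied lazily on the
    // next llama_decode(), with defragmentation driven by the 'defrag_thold' setting
    llama_decode(ctx, batch);
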
- LLAMA_API void llama_kv_self_update(struct llama_context * ctx); + LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx), + "simply remove this call, updates are applied lazily on the next llama_decode()"); // // State / sessions diff --git a/models/ggml-vocab-bert-bge.gguf.inp b/models/ggml-vocab-bert-bge.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-bert-bge.gguf.inp +++ b/models/ggml-vocab-bert-bge.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-bert-bge.gguf.out b/models/ggml-vocab-bert-bge.gguf.out index a62566ce7..b1c49672f 100644 --- a/models/ggml-vocab-bert-bge.gguf.out +++ b/models/ggml-vocab-bert-bge.gguf.out @@ -1,5 +1,5 @@ 29464 2094 1018 1092 2706 - 11865 17875 + 9706 7959 2140 diff --git a/models/ggml-vocab-chameleon.gguf.inp b/models/ggml-vocab-chameleon.gguf.inp deleted file mode 100644 index 9baf7d77a..000000000 --- a/models/ggml-vocab-chameleon.gguf.inp +++ /dev/null @@ -1,112 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -!!!!!! -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ -Cửa Việt -__ggml_vocab_test__ - discards -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? 
We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-chameleon.gguf.out b/models/ggml-vocab-chameleon.gguf.out deleted file mode 100644 index 7c5413fee..000000000 --- a/models/ggml-vocab-chameleon.gguf.out +++ /dev/null @@ -1,46 +0,0 @@ - 17245 16604 16403 16604 33583 18355 - 16421 51153 - - 16604 - 16650 - 16650 16604 - 16581 - 16582 - 16582 16582 - 16582 16582 16582 - 16581 16582 - 31596 17394 - 34926 17394 - 31596 18671 - 34926 18671 - 34926 18671 16384 - 31596 16395 17394 16384 - 34926 16395 17394 16384 - 16811 16704 20410 16483 16631 16397 52854 - 16470 16399 16403 16407 16604 16406 35764 38185 51595 22592 26639 - 29479 23955 17012 20103 25527 27670 17408 19005 21473 24774 - 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 21954 16607 21954 16633 21954 16611 29409 16607 21954 16615 - 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 16604 16391 24664 17153 57169 16721 16872 17073 17304 28729 16392 - 31596 - 34926 - 16650 31596 - 16650 34926 - 16696 31596 - 16696 31596 16582 16696 31596 - 16604 16391 - 16582 16604 16412 - 16390 22623 - 31596 16395 16712 16390 16828 16384 17674 16769 16732 23686 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 - 16384 16384 16384 16384 16384 16384 - 16402 - 16402 16402 - 16402 16402 16402 - 16402 16402 16402 16402 - 16402 16402 16402 16402 16402 - 16402 16402 16402 16402 16402 16402 - 16402 16402 16402 16402 16402 16402 16402 - 16402 16402 16402 16402 16402 16402 16402 16402 - 16402 16402 16402 16402 16402 16402 16402 16402 16402 - 16418 19038 16639 16448 24315 33727 16467 - 18765 17981 - 16582 16604 16582 16582 16604 16582 16582 16582 16604 16581 16604 16581 16581 16604 16581 16582 16650 16582 16650 16604 16582 16696 16582 16696 16604 16582 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 20410 16483 16631 18885 16483 16631 16604 16402 16604 16402 16402 16604 16402 16402 16402 16604 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16402 16604 16402 16397 16402 16604 16402 16397 16397 16402 16604 16402 16397 16397 16397 16402 16604 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 27683 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 16604 16396 16396 16396 16396 16396 16396 16412 16412 16412 16412 16412 16412 16412 27268 23955 17012 20103 25527 27670 17408 19005 21473 24774 16604 16390 16390 16390 16390 16390 16390 16447 16447 16447 16447 16447 16447 16447 16385 16385 16385 16385 16397 16397 16397 16397 16397 16397 16384 16384 16384 16384 16384 16384 16414 16414 16414 16414 16414 16414 16687 16390 16690 16992 16604 16390 61797 16733 16390 16466 16986 16395 16604 16390 17879 16732 17811 16414 16604 16390 16428 16804 17811 16687 16390 16683 17190 16728 16395 16604 16390 16419 16732 16945 16991 25251 16414 17119 16390 38127 16641 16390 16459 16427 diff --git a/models/ggml-vocab-command-r.gguf.inp b/models/ggml-vocab-command-r.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-command-r.gguf.inp 
+++ b/models/ggml-vocab-command-r.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-command-r.gguf.out b/models/ggml-vocab-command-r.gguf.out index 3f6b41888..0e3af72eb 100644 --- a/models/ggml-vocab-command-r.gguf.out +++ b/models/ggml-vocab-command-r.gguf.out @@ -1,5 +1,5 @@ 2536 228 27 228 22957 6983 - 45 193433 + 90711 87 20910 228 1667 diff --git a/models/ggml-vocab-deepseek-coder.gguf.inp b/models/ggml-vocab-deepseek-coder.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-deepseek-coder.gguf.inp +++ b/models/ggml-vocab-deepseek-coder.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-deepseek-coder.gguf.out b/models/ggml-vocab-deepseek-coder.gguf.out index 52c4111a1..ef6bc5b8a 100644 --- a/models/ggml-vocab-deepseek-coder.gguf.out +++ b/models/ggml-vocab-deepseek-coder.gguf.out @@ -1,5 +1,5 @@ 1050 207 19 207 19192 4217 - 37 32009 71 6247 + 125 213 26862 282 207 243 diff --git a/models/ggml-vocab-deepseek-llm.gguf.inp b/models/ggml-vocab-deepseek-llm.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-deepseek-llm.gguf.inp +++ b/models/ggml-vocab-deepseek-llm.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-deepseek-llm.gguf.out b/models/ggml-vocab-deepseek-llm.gguf.out index 0191b7a11..f9d49c9af 100644 --- a/models/ggml-vocab-deepseek-llm.gguf.out +++ b/models/ggml-vocab-deepseek-llm.gguf.out @@ -1,5 +1,5 @@ 1052 207 19 207 19109 4223 - 37 100014 71 6245 + 82077 26723 282 207 243 diff --git a/models/ggml-vocab-deepseek-r1-qwen.gguf.inp b/models/ggml-vocab-deepseek-r1-qwen.gguf.inp deleted file mode 100644 index 9baf7d77a..000000000 --- a/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +++ /dev/null @@ -1,112 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -!!!!!! 
-__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ -Cửa Việt -__ggml_vocab_test__ - discards -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-deepseek-r1-qwen.gguf.out b/models/ggml-vocab-deepseek-r1-qwen.gguf.out deleted file mode 100644 index 18b4b45cd..000000000 --- a/models/ggml-vocab-deepseek-r1-qwen.gguf.out +++ /dev/null @@ -1,46 +0,0 @@ - 1122 220 19 220 26062 3951 - 37 50753 261 - - 220 - 256 - 262 - 197 - 198 - 271 - 1406 - 1572 - 9707 1879 - 21927 1879 - 9707 4337 - 21927 4337 - 21927 4337 0 - 9707 11 1879 0 - 21927 11 1879 0 - 419 374 11162 99 247 13 10821 - 86 15 19 23 220 22 83 1963 41808 11472 2940 16739 - 78762 14144 1456 13073 63471 33594 3038 133178 79012 - 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 147805 148301 147270 44258 223 146848 - 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 320 3243 42365 429 702 1181 1828 3950 8 - 9707 - 21927 - 220 21927 - 256 21927 - 262 21927 - 262 21927 198 262 21927 - 320 - 198 284 - 6 11385 - 9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 - 17085 2928 - 18 - 18 18 - 18 18 18 - 18 18 18 18 - 18 18 18 18 18 - 18 18 18 18 18 18 - 18 18 18 18 18 18 18 - 18 18 18 18 18 18 18 18 - 18 18 18 18 18 18 18 18 18 - 34 90063 128324 - 2560 2347 - 198 4710 14731 65497 7847 1572 2303 78672 10947 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 11162 99 247 149955 220 18 220 18 18 220 18 18 18 220 18 18 18 18 220 18 18 18 18 18 220 18 18 18 18 18 18 220 18 18 18 18 18 18 18 220 18 18 18 18 18 18 18 18 220 18 13 18 220 18 496 18 220 18 1112 18 220 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 144534 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 55460 53237 18658 14144 1456 13073 63471 33594 3038 133178 79012 3355 4605 4605 13874 13874 73594 3014 3014 28149 17085 2928 26610 7646 358 3003 1012 364 83 813 566 594 1052 11 364 787 498 2704 30 364 44 537 2704 358 3278 1281 432 11 364 35 498 1075 1045 15243 30 1205 6 42612 264 63866 43 diff --git a/models/ggml-vocab-falcon.gguf.inp b/models/ggml-vocab-falcon.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-falcon.gguf.inp +++ b/models/ggml-vocab-falcon.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-falcon.gguf.out b/models/ggml-vocab-falcon.gguf.out index 64a48d97f..6319de60e 100644 --- a/models/ggml-vocab-falcon.gguf.out +++ b/models/ggml-vocab-falcon.gguf.out @@ -1,5 +1,5 @@ 878 204 31 3068 133 2137 - 28611 132 30042 + 34502 18614 286 204 258 diff --git a/models/ggml-vocab-gpt-2.gguf.inp b/models/ggml-vocab-gpt-2.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-gpt-2.gguf.inp +++ b/models/ggml-vocab-gpt-2.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months 
__ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-gpt-2.gguf.out b/models/ggml-vocab-gpt-2.gguf.out index 17a13bdfc..6464ded3d 100644 --- a/models/ggml-vocab-gpt-2.gguf.out +++ b/models/ggml-vocab-gpt-2.gguf.out @@ -1,5 +1,5 @@ 798 604 25208 1933 - 37 9116 71 11751 + 127 226 79 69 417 220 220 220 diff --git a/models/ggml-vocab-gpt-4o.gguf.inp b/models/ggml-vocab-gpt-4o.gguf.inp deleted file mode 100644 index 9baf7d77a..000000000 --- a/models/ggml-vocab-gpt-4o.gguf.inp +++ /dev/null @@ -1,112 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -!!!!!! -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ -Cửa Việt -__ggml_vocab_test__ - discards -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? 
We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-gpt-4o.gguf.out b/models/ggml-vocab-gpt-4o.gguf.out deleted file mode 100644 index 478df726f..000000000 --- a/models/ggml-vocab-gpt-4o.gguf.out +++ /dev/null @@ -1,46 +0,0 @@ - 1165 220 19 220 27124 5503 - 37 19194 259 - - 220 - 256 - 271 - 197 - 198 - 279 - 2499 - 2775 - 13225 2375 - 32949 2375 - 13225 5922 - 32949 5922 - 32949 5922 0 - 13225 11 2375 0 - 32949 11 2375 0 - 495 382 9552 99 247 13 17159 - 86 45404 220 22 10191 2852 22924 4750 6916 - 3907 53641 1235 185386 8118 - 11400 107516 15867 20804 22851 134178 77431 32010 104312 37984 16329 27751 89335 - 112927 222 350 14559 8 22861 114 2524 64364 104 15148 350 76466 166700 121942 780 8 91349 350 7393 74471 484 853 1617 2316 6602 8 - 13225 - 32949 - 220 32949 - 256 32949 - 271 32949 - 271 32949 198 271 32949 - 350 - 198 314 - 6 6837 - 13225 11 342 70653 0 3253 553 481 22861 223 1423 7522 18165 2178 34058 22369 16412 32999 16 867 8208 - 147475 - 18 - 2546 - 15517 - 15517 18 - 15517 2546 - 15517 15517 - 15517 15517 18 - 15517 15517 2546 - 15517 15517 15517 - 34 60213 53904 - 2960 3098 - 126470 25980 160432 16609 2775 4066 172261 19432 112927 222 350 14559 8 22861 114 2524 64364 104 15148 350 76466 166700 121942 780 8 91349 9552 99 247 4103 99 247 220 18 220 2546 220 15517 220 15517 18 220 15517 2546 220 15517 15517 220 15517 15517 18 220 15517 15517 2546 220 18 13 18 220 18 485 18 220 18 1008 18 44735 107516 15867 20804 22851 134178 77431 32010 104312 156437 1423 7522 18165 2178 34058 22369 16412 32999 16 867 8208 105024 106657 1967 53641 1235 185386 8118 22434 39336 26178 26178 168394 194663 27271 147475 25883 6961 9790 1339 461 83 1280 19016 1354 11 461 1099 481 3239 30 461 44 625 3239 17291 1520 480 11 461 35 481 1299 1236 17966 30 1416 6 27493 261 54602 43 diff --git a/models/ggml-vocab-llama-bpe.gguf.inp b/models/ggml-vocab-llama-bpe.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-llama-bpe.gguf.inp +++ b/models/ggml-vocab-llama-bpe.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-llama-bpe.gguf.out b/models/ggml-vocab-llama-bpe.gguf.out index 4b35cf93f..a77376625 100644 --- a/models/ggml-vocab-llama-bpe.gguf.out +++ b/models/ggml-vocab-llama-bpe.gguf.out @@ -1,5 +1,5 @@ 1142 220 19 220 27154 4038 - 37 51853 261 + 88075 16276 301 220 256 diff --git a/models/ggml-vocab-llama-spm.gguf.inp b/models/ggml-vocab-llama-spm.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-llama-spm.gguf.inp +++ b/models/ggml-vocab-llama-spm.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-llama-spm.gguf.out b/models/ggml-vocab-llama-spm.gguf.out index 93aacf8ba..2a71a6ef8 100644 --- a/models/ggml-vocab-llama-spm.gguf.out +++ b/models/ggml-vocab-llama-spm.gguf.out @@ -1,5 +1,5 @@ 474 287 29871 29946 29871 30226 7378 - 383 4000 261 + 11585 7810 295 259 1678 diff --git a/models/ggml-vocab-llama4.gguf.inp b/models/ggml-vocab-llama4.gguf.inp deleted file mode 100644 index 9baf7d77a..000000000 --- a/models/ggml-vocab-llama4.gguf.inp +++ /dev/null @@ -1,112 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world 
-__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -!!!!!! -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ -Cửa Việt -__ggml_vocab_test__ - discards -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-llama4.gguf.out b/models/ggml-vocab-llama4.gguf.out deleted file mode 100644 index 7ca46ce59..000000000 --- a/models/ggml-vocab-llama4.gguf.out +++ /dev/null @@ -1,46 +0,0 @@ - 1190 220 32 220 18215 7112 - 50 16800 258 - - 220 - 256 - 277 - 197 - 198 - 368 - 2946 - 3271 - 19873 3817 - 39715 3817 - 19873 7353 - 39715 7353 - 39715 7353 13 - 19873 24 3817 13 - 39715 24 3817 13 - 544 373 9522 112 247 26 36315 - 99 39923 220 35 9607 21498 21470 3679 9433 - 1595 7653 633 79829 34051 1636 - 8755 102595 115960 21125 148305 96819 102816 39048 14105 22528 160234 - 114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 330 7384 88230 511 947 1492 3742 7233 21 - 19873 - 39715 - 220 39715 - 256 39715 - 277 39715 - 277 39715 198 277 39715 - 330 - 198 319 - 19 7359 - 19873 24 386 87799 13 2403 583 650 51358 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645 - 17931 4959 - 31 - 1922 - 12325 - 12325 31 - 12325 1922 - 12325 12325 - 12325 12325 31 - 12325 12325 1922 - 12325 12325 12325 - 47 19811 12077 - 3260 3579 - 198 7283 51499 191231 20192 3271 3322 9287 2143 17860 114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 9522 112 247 172394 247 220 31 220 1922 220 12325 220 12325 31 220 12325 1922 220 12325 12325 220 12325 12325 31 220 12325 12325 1922 220 31 26 31 220 31 396 31 220 31 1043 31 117131 102595 115960 21125 148305 96819 102816 80883 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645 79745 150278 117079 633 79829 34051 1636 25611 41990 109428 1488 91054 24072 17931 4959 29795 9296 16517 1806 481 96 1386 36633 1609 24 481 1109 650 5074 43 481 57 702 5074 27088 2170 536 24 481 48 650 1933 1696 30262 43 1665 19 32818 262 27236 56 diff --git a/models/ggml-vocab-mpt.gguf.inp b/models/ggml-vocab-mpt.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-mpt.gguf.inp +++ b/models/ggml-vocab-mpt.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer 
+Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-mpt.gguf.out b/models/ggml-vocab-mpt.gguf.out index 372c751bf..ca62669ad 100644 --- a/models/ggml-vocab-mpt.gguf.out +++ b/models/ggml-vocab-mpt.gguf.out @@ -1,5 +1,5 @@ 728 577 24142 2607 - 39 26288 6554 + 37515 18569 293 209 50276 diff --git a/models/ggml-vocab-nomic-bert-moe.gguf b/models/ggml-vocab-nomic-bert-moe.gguf new file mode 100644 index 000000000..b6f4d9441 Binary files /dev/null and b/models/ggml-vocab-nomic-bert-moe.gguf differ diff --git a/models/ggml-vocab-phi-3.gguf.inp b/models/ggml-vocab-phi-3.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-phi-3.gguf.inp +++ b/models/ggml-vocab-phi-3.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-phi-3.gguf.out b/models/ggml-vocab-phi-3.gguf.out index 93aacf8ba..2a71a6ef8 100644 --- a/models/ggml-vocab-phi-3.gguf.out +++ b/models/ggml-vocab-phi-3.gguf.out @@ -1,5 +1,5 @@ 474 287 29871 29946 29871 30226 7378 - 383 4000 261 + 11585 7810 295 259 1678 diff --git a/models/ggml-vocab-pixtral.gguf.inp b/models/ggml-vocab-pixtral.gguf.inp deleted file mode 100644 index 9baf7d77a..000000000 --- a/models/ggml-vocab-pixtral.gguf.inp +++ /dev/null @@ -1,112 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -!!!!!! -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ -Cửa Việt -__ggml_vocab_test__ - discards -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? 
We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-pixtral.gguf.out b/models/ggml-vocab-pixtral.gguf.out deleted file mode 100644 index 53309d1bc..000000000 --- a/models/ggml-vocab-pixtral.gguf.out +++ /dev/null @@ -1,46 +0,0 @@ - 2014 1032 1052 1032 28504 6972 - 1070 7088 1258 - - 1032 - 1256 - 1293 - 1009 - 1010 - 1267 - 4688 - 1009 1010 - 22177 4304 - 45383 4304 - 22177 5325 - 45383 5325 - 45383 5325 1033 - 22177 1044 4304 1033 - 45383 1044 4304 1033 - 1593 1395 119685 1166 1153 1046 51228 - 1119 1048 1052 1056 1032 1055 17391 23216 30203 7785 17279 - 3337 30757 1902 4200 63073 3671 - 1225 1158 1128 1225 1158 1182 1225 1158 1147 1225 1159 1139 1225 1158 1143 1225 1159 1130 1225 1158 1150 1225 1158 1183 1225 1158 1159 1225 21359 1225 1158 1159 1225 1158 1162 1225 1158 1182 1225 1158 1133 1225 1158 1129 1225 1158 1155 1225 1158 1133 1225 21359 1225 1158 1137 - 1240 1159 1154 1128 1319 13052 1041 119685 1152 1182 29568 1240 1159 1140 1171 1239 1184 1143 1319 88181 1873 3659 1275 56421 1621 1041 126241 1133 1319 11234 1873 26303 1455 1934 2246 3754 10835 1041 - 22177 - 45383 - 1032 45383 - 1256 45383 - 1293 45383 - 1293 45383 1010 1293 45383 - 1319 - 1010 1376 - 1039 4033 - 22177 1044 1404 48054 1033 3075 1584 1636 119685 1152 1129 3082 26060 2998 63614 82278 1049 1051 1049 1052 1049 1053 1049 6434 6749 - 7290 7290 7290 - 1051 - 1051 1051 - 1051 1051 1051 - 1051 1051 1051 1051 - 1051 1051 1051 1051 1051 - 1051 1051 1051 1051 1051 1051 - 1051 1051 1051 1051 1051 1051 1051 - 1051 1051 1051 1051 1051 1051 1051 1051 - 1051 1051 1051 1051 1051 1051 1051 1051 1051 - 1067 59503 28783 - 3724 4058 - 1010 1032 1267 1032 4688 1032 17152 1458 29356 1010 1256 1010 1293 1010 1260 1010 1652 1010 1240 1159 1154 1128 1319 13052 1041 119685 1152 1182 29568 1240 1159 1140 1171 1239 1184 1143 1319 88181 1873 3659 1275 56421 1621 1041 126241 1133 119685 1166 1153 1240 1159 1166 1153 1032 1051 1032 1051 1051 1032 1051 1051 1051 1032 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1051 1051 1051 1032 1051 1046 1051 1032 1051 1791 1051 1032 1051 2880 1051 71881 1158 1128 1225 1158 1182 1225 1158 1147 1225 1159 1139 1225 1158 1143 1225 1159 1130 1225 1158 1150 1225 1158 1183 1225 1158 1159 1225 21359 1225 1158 1159 1225 1158 1162 1225 1158 1182 1225 1158 1133 1240 1159 1152 1129 3082 26060 2998 63614 82278 1049 1051 1049 1052 1049 1053 1049 6434 6749 45577 1045 6626 43555 2843 30757 1902 4200 63073 3671 14931 20040 20040 1657 1657 1975 14135 14135 83923 7290 7290 7290 45509 45509 45509 1362 6483 2151 1576 1116 2189 1514 1681 2156 1044 1576 3609 1636 5257 1063 1576 1077 1605 5257 1362 7534 3180 1494 1044 1576 1068 1636 2479 2269 26883 1063 2837 1039 45654 1261 54297 1076 diff --git a/models/ggml-vocab-qwen2.gguf.inp b/models/ggml-vocab-qwen2.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-qwen2.gguf.inp +++ b/models/ggml-vocab-qwen2.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-qwen2.gguf.out b/models/ggml-vocab-qwen2.gguf.out index 18b4b45cd..595d59a44 100644 --- a/models/ggml-vocab-qwen2.gguf.out +++ b/models/ggml-vocab-qwen2.gguf.out @@ -1,5 +1,5 @@ 1122 220 19 220 26062 3951 - 37 50753 261 + 86975 15897 301 220 256 diff --git a/models/ggml-vocab-refact.gguf.inp b/models/ggml-vocab-refact.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-refact.gguf.inp +++ 
b/models/ggml-vocab-refact.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-refact.gguf.out b/models/ggml-vocab-refact.gguf.out index 63d8305c3..f13dda52c 100644 --- a/models/ggml-vocab-refact.gguf.out +++ b/models/ggml-vocab-refact.gguf.out @@ -1,5 +1,5 @@ 4833 225 38 225 143 140 17723 - 56 2006 3935 265 + 144 231 7132 342 225 261 diff --git a/models/ggml-vocab-roberta-bpe.gguf.inp b/models/ggml-vocab-roberta-bpe.gguf.inp deleted file mode 100644 index 9baf7d77a..000000000 --- a/models/ggml-vocab-roberta-bpe.gguf.inp +++ /dev/null @@ -1,112 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -!!!!!! -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ -Cửa Việt -__ggml_vocab_test__ - discards -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? 
We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-roberta-bpe.gguf.out b/models/ggml-vocab-roberta-bpe.gguf.out deleted file mode 100644 index f181ac3dc..000000000 --- a/models/ggml-vocab-roberta-bpe.gguf.out +++ /dev/null @@ -1,46 +0,0 @@ - 2550 204 18430 377 - 597 2768 298 8564 - - 1437 - 1437 1437 - 1437 1437 1437 - 50117 - 50118 - 50140 - 50140 50118 - 50117 50118 - 31414 232 - 20920 232 - 31414 623 - 20920 623 - 20920 623 328 - 31414 6 232 328 - 20920 6 232 328 - 42 16 8103 18164 27 4 49317 - 605 40976 262 10109 18474 385 29 36807 6455 - 36765 25482 22063 23171 34251 18697 10809 26161 18697 3602 22063 27969 40966 25417 15264 26161 24269 36709 41171 35328 - 1376 17772 7471 1376 17772 19002 1376 17772 9085 1376 4333 13859 1376 17772 9357 1376 4333 9264 1376 17772 25448 1376 17772 18400 1376 17772 4333 1376 4333 10172 1376 17772 4333 1376 17772 7258 1376 17772 19002 1376 17772 5782 1376 17772 10172 1376 17772 3726 1376 17772 5782 1376 4333 10172 1376 17772 23171 - 6569 15113 7471 36 21113 43 17841 19002 17 8384 6569 14285 4958 12605 36 34654 2841 4203 354 10146 26511 1070 43 36174 5782 36 8338 21554 14 34 63 308 19233 43 - 31414 - 20920 - 1437 20920 - 1437 1437 20920 - 1437 1437 1437 20920 - 1437 1437 1437 20920 50118 1437 1437 1437 20920 - 36 - 50118 5457 - 108 3567 - 31414 6 1423 108 1250 328 1336 32 47 17841 10172 17487 47876 3602 48617 15264 46537 11423 27326 48494 8210 49233 1558 1570 27761 49429 43251 10809 17772 - 32376 12846 - 246 - 3103 - 25631 - 46152 - 3103 25631 - 46152 3103 - 46152 25631 - 46152 46152 - 46152 3103 25631 - 347 1376 2023 12410 102 16376 1376 2023 6382 90 - 9553 5954 - 50118 1437 50140 1437 50140 50118 1437 50117 1437 50117 50117 1437 50117 50118 1437 1437 50118 1437 1437 1437 50118 1437 1437 1437 1437 50118 1437 1437 1437 1437 1437 50118 6569 15113 7471 36 21113 43 17841 19002 17 8384 6569 14285 4958 12605 36 34654 2841 4203 354 10146 26511 1070 43 36174 5782 8103 18164 27 6569 18164 27 155 2357 30242 155 25631 30242 3103 30242 25631 30242 46152 30242 3103 25631 155 4 246 155 7586 246 155 734 246 25974 17772 7471 1376 17772 19002 1376 17772 9085 1376 4333 13859 1376 17772 9357 1376 4333 9264 1376 17772 25448 1376 17772 18400 1376 17772 4333 1376 4333 10172 1376 17772 4333 1376 17772 7258 1376 17772 19002 1376 17772 5782 18636 10172 17487 47876 3602 48617 15264 46537 11423 27326 48494 8210 49233 1558 1570 27761 49429 43251 10809 17772 36738 48332 47463 18697 10809 25482 22063 23171 34251 18697 10809 26161 18697 3602 22063 27969 40966 25417 15264 26161 24269 36709 41171 35328 128 49690 108 49972 49519 12905 48149 48149 43796 32376 12846 27282 28749 38 348 57 128 41042 37 18 89 6 128 4629 47 686 116 128 448 45 686 38 581 146 24 6 128 495 47 101 103 6845 116 166 108 30660 10 108 462 574 diff --git a/models/ggml-vocab-starcoder.gguf.inp b/models/ggml-vocab-starcoder.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-starcoder.gguf.inp +++ b/models/ggml-vocab-starcoder.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-starcoder.gguf.out b/models/ggml-vocab-starcoder.gguf.out index 87e2465d3..4698e2c3c 100644 --- a/models/ggml-vocab-starcoder.gguf.out +++ b/models/ggml-vocab-starcoder.gguf.out @@ -1,5 +1,5 @@ 4850 244 57 244 162 159 17722 - 75 2022 3943 284 + 163 250 7146 361 244 280 diff --git a/models/templates/Qwen-QwQ-32B.jinja b/models/templates/Qwen-QwQ-32B.jinja new file mode 100644 index 000000000..d475f7068 --- /dev/null +++ 
b/models/templates/Qwen-QwQ-32B.jinja @@ -0,0 +1,62 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0]['role'] == 'system' %} + {{- messages[0]['content'] }} + {%- else %} + {{- '' }} + {%- endif %} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" and not message.tool_calls %} + {%- set content = message.content %} + {%- if not loop.last %} + {%- set content = message.content.split('')[-1].lstrip('\n') %} + {%- endif %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set content = message.content %} + {%- if not loop.last %} + {%- set content = message.content.split('')[-1].lstrip('\n') %} + {%- endif %} + {{- '<|im_start|>' + message.role }} + {%- if message.content %} + {{- '\n' + content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments | tojson }} + {{- '}\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n\n' }} +{%- endif %} diff --git a/models/templates/Qwen-Qwen3-0.6B.jinja b/models/templates/Qwen-Qwen3-0.6B.jinja new file mode 100644 index 000000000..699ff8df4 --- /dev/null +++ b/models/templates/Qwen-Qwen3-0.6B.jinja @@ -0,0 +1,85 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set 
ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set content = message.content %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is defined and message.reasoning_content is not none %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in message.content %} + {%- set content = message.content.split('')[-1].lstrip('\n') %} + {%- set reasoning_content = message.content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/models/templates/README.md b/models/templates/README.md index e4fd104fc..35b6386dd 100644 --- a/models/templates/README.md +++ b/models/templates/README.md @@ -19,4 +19,6 @@ These templates can be updated with the following commands: ./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja +./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja +./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja ``` \ No newline at end of file diff --git a/requirements/requirements-convert_hf_to_gguf.txt b/requirements/requirements-convert_hf_to_gguf.txt index 8cb9c354f..431c596c1 100644 --- a/requirements/requirements-convert_hf_to_gguf.txt +++ b/requirements/requirements-convert_hf_to_gguf.txt @@ -1,3 +1,7 @@ -r ./requirements-convert_legacy_llama.txt --extra-index-url https://download.pytorch.org/whl/cpu -torch~=2.2.1 +torch~=2.2.1; 
platform_machine != "s390x" + +# torch s390x packages can only be found from nightly builds +--extra-index-url https://download.pytorch.org/whl/nightly +torch>=0.0.0.dev0; platform_machine == "s390x" diff --git a/requirements/requirements-convert_hf_to_gguf_update.txt b/requirements/requirements-convert_hf_to_gguf_update.txt index 8cb9c354f..431c596c1 100644 --- a/requirements/requirements-convert_hf_to_gguf_update.txt +++ b/requirements/requirements-convert_hf_to_gguf_update.txt @@ -1,3 +1,7 @@ -r ./requirements-convert_legacy_llama.txt --extra-index-url https://download.pytorch.org/whl/cpu -torch~=2.2.1 +torch~=2.2.1; platform_machine != "s390x" + +# torch s390x packages can only be found from nightly builds +--extra-index-url https://download.pytorch.org/whl/nightly +torch>=0.0.0.dev0; platform_machine == "s390x" diff --git a/requirements/requirements-convert_lora_to_gguf.txt b/requirements/requirements-convert_lora_to_gguf.txt index 5758076c4..d091d5648 100644 --- a/requirements/requirements-convert_lora_to_gguf.txt +++ b/requirements/requirements-convert_lora_to_gguf.txt @@ -1,2 +1,4 @@ -r ./requirements-convert_hf_to_gguf.txt --extra-index-url https://download.pytorch.org/whl/cpu +# torch s390x packages can only be found from nightly builds +--extra-index-url https://download.pytorch.org/whl/nightly diff --git a/requirements/requirements-gguf_editor_gui.txt b/requirements/requirements-gguf_editor_gui.txt index 920dc7cf9..fd253364e 100644 --- a/requirements/requirements-gguf_editor_gui.txt +++ b/requirements/requirements-gguf_editor_gui.txt @@ -1,3 +1,3 @@ numpy~=1.26.4 PySide6~=6.9.0 -gguf>=0.16.0 +gguf>=0.17.0 diff --git a/scripts/compare-commits.sh b/scripts/compare-commits.sh index e40d1cc6d..94a8eceb3 100755 --- a/scripts/compare-commits.sh +++ b/scripts/compare-commits.sh @@ -17,14 +17,14 @@ rm -f llama-bench.sqlite > /dev/null # to test a backend, call the script with the corresponding environment variable (e.g. GGML_CUDA=1 ./scripts/compare-commits.sh ...) if [ -n "$GGML_CUDA" ]; then - cmake_opts="-DGGML_CUDA=ON" + CMAKE_OPTS="${CMAKE_OPTS} -DGGML_CUDA=ON" fi dir="build-bench" function run { rm -fr ${dir} > /dev/null - cmake -B ${dir} -S . $cmake_opts > /dev/null + cmake -B ${dir} -S . 
${CMAKE_OPTS} > /dev/null cmake --build ${dir} -t llama-bench > /dev/null ${dir}/bin/llama-bench -o sql -oe md $bench_args | sqlite3 llama-bench.sqlite } diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index b141cabc9..aa0fb8fb0 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -7c06c10c532a6cda913c17fc56341e8880ae341d +94a83ba5a725ae2aee79df75dd99b2119d0478cc diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py new file mode 100755 index 000000000..1151c9f01 --- /dev/null +++ b/scripts/sync_vendor.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +import urllib.request + +vendor = { + "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", + "https://github.com/nlohmann/json/releases/latest/download/json_fwd.hpp": "vendor/nlohmann/json_fwd.hpp", + + # sync manually + # "https://raw.githubusercontent.com/ochafik/minja/refs/heads/main/include/minja/minja.hpp": "vendor/minja/minja.hpp", + # "https://raw.githubusercontent.com/ochafik/minja/refs/heads/main/include/minja/chat-template.hpp": "vendor/minja/chat-template.hpp", + + "https://raw.githubusercontent.com/nothings/stb/refs/heads/master/stb_image.h": "vendor/stb/stb_image.h", + + "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.22/miniaudio.h": "vendor/miniaudio/miniaudio.h", + + "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.20.1/httplib.h": "vendor/cpp-httplib/httplib.h", +} + +for url, filename in vendor.items(): + print(f"downloading {url} to {filename}") # noqa: NP100 + urllib.request.urlretrieve(url, filename) diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py index a2f2a2eb0..d8018e2e2 100755 --- a/scripts/tool_bench.py +++ b/scripts/tool_bench.py @@ -12,6 +12,7 @@ export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp} + ./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b @@ -205,6 +206,7 @@ def run( model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None, hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None, chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None, + chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None, ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None, llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None, n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10, @@ -229,6 +231,12 @@ def run( # n_ctx = 8192 n_ctx = 2048 + if model is None: + if hf is not None: + model = hf.split("/")[-1] + elif ollama is not None: + model = ollama + assert force or append or not output.exists(), f"Output file already exists: {output}; use 
--force to overwrite" with output.open('a' if append else 'w') as output_file: @@ -320,6 +328,7 @@ def run( server.model_hf_repo = hf server.model_hf_file = None server.chat_template = chat_template + server.chat_template_file = chat_template_file server.server_path = server_path if port is not None: server.server_port = port @@ -335,6 +344,7 @@ def run( temp=t, output_kwargs=dict( chat_template=chat_template, + chat_template_file=chat_template_file, ), request_kwargs=dict( ignore_chat_grammar=ignore_chat_grammar, @@ -355,6 +365,7 @@ def run( temp=t, output_kwargs=dict( chat_template=None, + chat_template_file=None, ), request_kwargs=dict( model=ollama, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d4bf37b1c..d20bd4fe2 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,12 +14,16 @@ add_library(llama llama-batch.cpp llama-chat.cpp llama-context.cpp + llama-cparams.cpp llama-grammar.cpp llama-graph.cpp llama-hparams.cpp llama-impl.cpp llama-io.cpp llama-kv-cache.cpp + llama-kv-cache-unified.cpp + llama-kv-cache-unified-iswa.cpp + llama-kv-cache-recurrent.cpp llama-memory.cpp llama-mmap.cpp llama-model-loader.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 19f58ba82..0ef81054b 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -175,6 +175,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" }, { LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" }, + { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, @@ -449,6 +451,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_TYPES, "token_types" }, { LLM_TENSOR_POS_EMBD, "position_embd" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index c696b4113..1dcd4fa35 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -214,6 +214,8 @@ enum llm_kv { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, LLM_KV_CONVNEXT_BLOCK_COUNT, + LLM_KV_CLASSIFIER_OUTPUT_LABELS, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index b98e3256c..6a19a2431 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -15,24 +15,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { break; } } - ubatch_token.resize(!has_embd ? n_ubatch : 0); - ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); - ubatch_pos.resize(n_ubatch); - ubatch_n_seq_id.resize(n_ubatch); - ubatch_seq_id.resize(n_ubatch); - ubatch_output.resize(n_ubatch); + + udatas.push_back({}); + + auto & udata = udatas.back(); + + udata.token.resize(!has_embd ? n_ubatch : 0); + udata.embd.resize(has_embd ? n_embd * n_ubatch : 0); + udata.pos.resize(n_ubatch); + udata.n_seq_id.resize(n_ubatch); + udata.seq_id.resize(n_ubatch); + udata.output.resize(n_ubatch); + llama_ubatch ubatch = { /*equal_seqs =*/ true, /*n_tokens =*/ 0, /*n_seq_tokens =*/ 0, /*n_seqs =*/ 0, - /*token =*/ !has_embd ? ubatch_token.data() : nullptr, - /*embd =*/ has_embd ? 
ubatch_embd.data() : nullptr, - /*pos =*/ ubatch_pos.data(), - /*n_seq_id =*/ ubatch_n_seq_id.data(), - /*seq_id =*/ ubatch_seq_id.data(), - /*output =*/ ubatch_output.data(), + /*token =*/ !has_embd ? udata.token.data() : nullptr, + /*embd =*/ has_embd ? udata.embd.data() : nullptr, + /*pos =*/ udata.pos.data(), + /*n_seq_id =*/ udata.n_seq_id.data(), + /*seq_id =*/ udata.seq_id.data(), + /*output =*/ udata.output.data(), }; + return ubatch; } diff --git a/src/llama-batch.h b/src/llama-batch.h index 6305051b6..b8260b94f 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -11,15 +11,15 @@ struct llama_ubatch { bool equal_seqs; // TODO: whole_seqs for embeddings? - uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) + uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) uint32_t n_seq_tokens; // tokens per sequence uint32_t n_seqs; llama_token * token; // [n_tokens] float * embd; // [n_embd, n_tokens] llama_pos * pos; // [n_tokens] - int32_t * n_seq_id; // [n_seqs] - llama_seq_id ** seq_id; // [n_seqs] + int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence + llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id; int8_t * output; // [n_tokens] }; @@ -49,13 +49,18 @@ struct llama_sbatch { const llama_batch * batch = nullptr; - // buffers for the ubatch - std::vector ubatch_token; - std::vector ubatch_embd; - std::vector ubatch_pos; - std::vector ubatch_n_seq_id; - std::vector ubatch_seq_id; - std::vector ubatch_output; + // buffers for the ubatches + // TODO: very hacky, this needs a complete rework + struct ubatch_data { + std::vector token; + std::vector embd; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector output; + }; + + std::vector udatas; llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 85b4324b6..4ab574387 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -6,9 +6,10 @@ #include "llama-model.h" #include "llama-kv-cache.h" -#include -#include #include +#include +#include +#include // // llama_context @@ -25,7 +26,11 @@ llama_context::llama_context( const auto & hparams = model.hparams; - cparams.n_seq_max = std::max(1u, params.n_seq_max); + cparams.n_seq_max = std::max(1u, params.n_seq_max); + if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) { + throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES)); + } + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -118,6 +123,11 @@ llama_context::llama_context( __func__, n_ctx_per_seq, hparams.n_ctx_train); } + if (!params.swa_full && cparams.n_seq_max > 1) { + LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n", + __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573"); + } + if (!hparams.vocab_only) { // GPU backends for (auto * dev : model.devices) { @@ -255,15 +265,9 @@ llama_context::llama_context( // reserve worst-case graph if (!hparams.vocab_only && memory) { - const uint32_t n_seqs = 1; // TODO: worst-case number of sequences + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - - 
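
The `llama-batch.h` hunk above replaces the single set of reusable `ubatch_*` buffers with one `ubatch_data` entry per reserved ubatch (`udatas`), so the raw pointers inside a previously returned `llama_ubatch` stay valid while later ubatches are reserved. The snippet below is a generic sketch of that ownership pattern with invented names (`TokenView`, `TokenPool`), not the llama.cpp types.

```cpp
// Sketch of the pattern behind llama_sbatch::udatas: every view handed out keeps
// its backing storage alive inside the producer, so reserving a second view no
// longer clobbers the pointers of the first one.
#include <cstdint>
#include <cstdio>
#include <vector>

struct TokenView {
    const int32_t * token; // non-owning, points into the pool
    size_t          n;
};

class TokenPool {
public:
    TokenView reserve(std::vector<int32_t> tokens) {
        // push a fresh buffer instead of reusing one shared buffer; growing the
        // outer vector only moves the inner vectors, so their heap storage (and
        // the pointers handed out) stays in place
        storage.push_back(std::move(tokens));
        const auto & buf = storage.back();
        return { buf.data(), buf.size() };
    }

private:
    std::vector<std::vector<int32_t>> storage; // one entry per reserved view
};

int main() {
    TokenPool pool;
    TokenView a = pool.reserve({1, 2, 3});
    TokenView b = pool.reserve({4, 5}); // does not invalidate a

    std::printf("a[0]=%d b[0]=%d\n", a.token[0], b.token[0]);
}
```
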
// restore later - // TODO: something cleaner - const auto n_outputs_save = n_outputs; - LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); int n_splits_pp = -1; @@ -275,23 +279,17 @@ llama_context::llama_context( // simulate full KV cache llama_kv_cache * kv_self = static_cast(memory.get()); - kv_self->set_full(); + const auto kv_state = kv_self->init_full(); + if (!kv_state) { + throw std::runtime_error("failed to initialize KV cache"); + } cross.v_embd.clear(); // reserve pp graph first so that buffers are only allocated once { - llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - // max number of outputs - n_outputs = ubatch_pp.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); + if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -301,16 +299,8 @@ llama_context::llama_context( // reserve with tg graph to get the number of splits and nodes { - llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - n_outputs = ubatch_tg.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(1, 1, 1, kv_state.get()); + if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -320,22 +310,12 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - n_outputs = ubatch_pp.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); + if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } } - n_outputs = n_outputs_save; - for (size_t i = 0; i < backend_ptrs.size(); ++i) { ggml_backend_t backend = backend_ptrs[i]; ggml_backend_buffer_type_t buft = backend_buft[i]; @@ -449,36 +429,33 @@ const llama_kv_cache * llama_context::get_kv_self() const { return kv_self; } -void llama_context::kv_self_update() { - bool need_reserve = false; +bool llama_context::kv_self_update() { + if (!memory) { + return false; + } llama_kv_cache * kv_self = static_cast(memory.get()); - need_reserve = kv_self->update(*this); - - // reserve a worst case graph if needed - if (need_reserve) { - LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__); - - // build worst-case graph - uint32_t n_seqs = 1; // TODO: worst-case number of sequences - uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - // simulate full KV cache - kv_self->set_full(); - - llama_token token = 
model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); - - // initialize scheduler with the worst-case graph - ggml_backend_sched_reset(sched.get()); - if (!ggml_backend_sched_reserve(sched.get(), gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); - } + if (!kv_self->update(*this)) { + // no updates have been performed + return false; } + + // if the KV cache did any computation, we have to reserve a new worst-case graph + const auto kv_state = kv_self->init_full(); + if (!kv_state) { + throw std::runtime_error("failed to initialize KV cache"); + } + + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__); + } + + return true; } enum llama_pooling_type llama_context::pooling_type() const { @@ -672,6 +649,49 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } +llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) { + if (mstate && !mstate->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + auto * gf = graph_init(); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + + if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); + ret = GGML_STATUS_ALLOC_FAILED; + return nullptr; + } + + res->set_inputs(&ubatch); + + const auto status = graph_compute(gf, ubatch.n_tokens > 1); + if (status != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); + ret = status; + return nullptr; + } + + ret = GGML_STATUS_SUCCESS; + + return res; +} + int llama_context::encode(llama_batch & inp_batch) { if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -689,12 +709,18 @@ int llama_context::encode(llama_batch & inp_batch) { GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + // TODO: move the validation to the llama_batch_allocr if (batch.token) { for (int32_t i = 0; i < n_tokens; ++i) { if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); return -1; } + + if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES); + throw -1; + } } } @@ -727,8 +753,6 @@ int 
llama_context::encode(llama_batch & inp_batch) { n_outputs = n_tokens; - //batch_manager->prepare(ubatch); - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); @@ -739,26 +763,18 @@ int llama_context::encode(llama_batch & inp_batch) { // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 cparams.causal_attn = false; - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER); - - ggml_backend_sched_alloc_graph(sched.get(), gf); - - res->set_inputs(&ubatch); + ggml_status status; + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); cparams.causal_attn = causal_attn_org; - const auto compute_status = graph_compute(gf, n_tokens > 1); - switch (compute_status) { - case GGML_STATUS_SUCCESS: - break; - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } } auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); @@ -848,7 +864,7 @@ int llama_context::encode(llama_batch & inp_batch) { int llama_context::decode(llama_batch & inp_batch) { if (!memory) { - LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__); + LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); return encode(inp_batch); } @@ -879,15 +895,19 @@ int llama_context::decode(llama_batch & inp_batch) { const int64_t n_tokens_all = batch.n_tokens; const int64_t n_embd = hparams.n_embd; - llama_kv_cache_guard kv_guard(kv_self); - GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + // TODO: move the validation to the llama_batch_allocr if (batch.token) { for (int64_t i = 0; i < n_tokens_all; ++i) { if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]); - throw std::runtime_error("invalid token"); + return -1; + } + + if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES); + return -1; } } } @@ -920,7 +940,48 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs_all = 1; } - llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all); + // handle any pending defrags/shifts + kv_self_update(); + + llama_memory_state_ptr kv_state; + + bool did_defrag = false; + + while (true) { + kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all); + if (!kv_state) { + return -2; + } + + switch (kv_state->get_status()) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + } break; + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + { + if (!did_defrag) { + did_defrag = true; + + kv_self->defrag_sched(-1.0f); + if (kv_self_update()) { + LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens); + + continue; + } + } + + LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", 
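
With this patch, `encode()` and `decode()` reject tokens whose first `seq_id` falls outside `[0, LLAMA_MAX_PARALLEL_SEQUENCES)` up front instead of failing later. Below is a sketch of filling a `llama_batch` by hand with explicit, in-range sequence ids; the `add_token` helper is invented for illustration (the `common_batch_add` helper in `common/` plays the same role), and context/model setup is omitted.

```cpp
// Sketch: hand-filling a llama_batch with explicit, in-range sequence ids.
// n_seq_max must match the value passed in llama_context_params and stay within
// the per-context sequence cap introduced by this change.
#include "llama.h"

static void add_token(llama_batch & batch, llama_token id, llama_pos pos,
                      llama_seq_id seq_id, bool want_logits) {
    const int i = batch.n_tokens;
    batch.token   [i]    = id;
    batch.pos     [i]    = pos;
    batch.n_seq_id[i]    = 1;
    batch.seq_id  [i][0] = seq_id;   // must be >= 0 and < n_seq_max
    batch.logits  [i]    = want_logits ? 1 : 0;
    batch.n_tokens++;
}

int main() {
    const int32_t n_seq_max = 2;

    llama_batch batch = llama_batch_init(/*n_tokens*/ 8, /*embd*/ 0, n_seq_max);

    // two independent sequences in one batch
    add_token(batch, /*id*/ 1, /*pos*/ 0, /*seq_id*/ 0, false);
    add_token(batch, /*id*/ 2, /*pos*/ 1, /*seq_id*/ 0, true);
    add_token(batch, /*id*/ 1, /*pos*/ 0, /*seq_id*/ 1, false);
    add_token(batch, /*id*/ 7, /*pos*/ 1, /*seq_id*/ 1, true);

    // ... llama_decode(ctx, batch) would be called here ...

    llama_batch_free(batch);
    return 0;
}
```
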
__func__, batch.n_tokens); + + return 1; + } + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return -2; + } + } + + break; + } // reserve output buffer if (output_reserve(n_outputs_all) < n_outputs_all) { @@ -928,13 +989,10 @@ int llama_context::decode(llama_batch & inp_batch) { return -2; }; - // handle any pending defrags/shifts - kv_self_update(); - int64_t n_outputs_prev = 0; - while (sbatch.n_tokens > 0) { - llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled); + do { + const auto & ubatch = kv_state->get_ubatch(); // count the outputs in this u_batch { @@ -953,33 +1011,37 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs = n_outputs_new; } - // find KV slot - if (!kv_self->find_slot(ubatch)) { - return 1; - } - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER); + ggml_status status; + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status); - // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + if (!res) { + // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache + llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits::max() }; - ggml_backend_sched_alloc_graph(sched.get(), gf); + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + const auto & seq_id = ubatch.seq_id[i][0]; - res->set_inputs(&ubatch); + pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]); + } - const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1); - if (compute_status != GGML_STATUS_SUCCESS) { - switch (compute_status) { - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (pos_min[s] == std::numeric_limits::max()) { + continue; + } + + LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); + + llama_kv_self_seq_rm(this, s, pos_min[s], -1); + } + + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); } } @@ -1066,10 +1128,7 @@ int llama_context::decode(llama_batch & inp_batch) { } n_outputs_prev += n_outputs; - } - - // finalize the batch processing - kv_guard.commit(); + } while (kv_state->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith n_outputs = n_outputs_all; @@ -1078,7 +1137,7 @@ int llama_context::decode(llama_batch & inp_batch) { { bool sorted_output = true; - auto & out_ids = sbatch.out_ids; + auto & out_ids = kv_state->out_ids(); GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); @@ -1238,11 +1297,52 @@ ggml_cgraph * llama_context::graph_init() { return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); } +ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) { + LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); + + if (n_tokens % n_seqs != 0) { + n_tokens = (n_tokens / n_seqs) * n_seqs; + n_outputs = 
std::min(n_outputs, n_tokens); + + LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); + } + + // store the n_outputs as it is, and restore it afterwards + // TODO: not sure if needed, might simplify in the future by removing this + const auto save_n_outputs = this->n_outputs; + + this->n_outputs = n_outputs; + + llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate); + + this->n_outputs = save_n_outputs; + + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__); + return nullptr; + } + + ggml_backend_sched_reset(sched.get()); + + // initialize scheduler with the specified graph + if (!ggml_backend_sched_reserve(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); + return nullptr; + } + + return gf; +} + llm_graph_result_ptr llama_context::graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype) { + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_state_i * mstate) { return model.build_graph( { /*.ctx =*/ ctx, @@ -1254,7 +1354,7 @@ llm_graph_result_ptr llama_context::graph_build( /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, - /*.memory =*/ memory.get(), + /*.mstate =*/ mstate, /*.cross =*/ &cross, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), @@ -1935,7 +2035,6 @@ void llama_context::opt_epoch_iter( llama_kv_cache * kv_self = static_cast(memory.get()); kv_self->clear(); - llama_kv_cache_guard kv_guard(kv_self); for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { batch.n_tokens = n_batch; @@ -1958,7 +2057,11 @@ void llama_context::opt_epoch_iter( int64_t n_outputs_all = n_tokens_all; - llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true); + auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true); + if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); + break; + } // reserve output buffer if (output_reserve(n_outputs_all) < n_outputs_all) { @@ -1966,20 +2069,19 @@ void llama_context::opt_epoch_iter( GGML_ABORT("TODO: handle this error"); }; - for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) { - llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled); + uint32_t pos_batch = 0; + do { + const auto & ubatch = kv_state->get_ubatch(); n_outputs = ubatch.n_tokens; - // TODO: not sure if this is needed - if (!kv_self->find_slot(ubatch)) { - LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); - - GGML_ABORT("TODO: handle this error"); + if (!kv_state->apply()) { + LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__); + break; } auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get()); struct ggml_context * ctx_compute_opt; { @@ -1994,6 
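
The new `graph_reserve()` builds its dummy ubatch from `n_seqs = cparams.n_seq_max`, so it first rounds `n_tokens` down to a multiple of `n_seqs` and clamps `n_outputs` accordingly. A tiny worked example of that arithmetic, with illustrative numbers:

```cpp
// Sketch of the rounding in llama_context::graph_reserve: the dummy ubatch needs
// n_tokens divisible by n_seqs, and n_outputs can never exceed n_tokens.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    uint32_t n_tokens  = 512;
    uint32_t n_seqs    = 3;   // e.g. cparams.n_seq_max
    uint32_t n_outputs = 512;

    if (n_tokens % n_seqs != 0) {
        n_tokens  = (n_tokens / n_seqs) * n_seqs;  // 512 -> 510
        n_outputs = std::min(n_outputs, n_tokens); // 512 -> 510
    }

    std::printf("n_tokens = %u, n_seqs = %u, n_outputs = %u\n", n_tokens, n_seqs, n_outputs);
}
```
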
+2096,7 @@ void llama_context::opt_epoch_iter( } ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); ggml_opt_alloc(opt_ctx, train); + res->set_inputs(&ubatch); { struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); @@ -2011,10 +2114,10 @@ void llama_context::opt_epoch_iter( callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); } ggml_free(ctx_compute_opt); - } - } - kv_guard.commit(); + pos_batch += ubatch.n_tokens; + } while (kv_state->next()); + } } void llama_context::opt_epoch( @@ -2178,6 +2281,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) { return ctx->get_kv_self(); } +// deprecated void llama_kv_self_update(llama_context * ctx) { ctx->kv_self_update(); } @@ -2432,6 +2536,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { return kv->seq_pos_max(seq_id); } +// deprecated void llama_kv_self_defrag(llama_context * ctx) { auto * kv = ctx->get_kv_self(); if (!kv) { @@ -2573,22 +2678,8 @@ int32_t llama_encode( int32_t llama_decode( llama_context * ctx, llama_batch batch) { - int ret = ctx->decode(batch); - - // defrag and try again - // TODO: distinguish return code when we are sure that even after defrag there is no space available - if (ret == 1) { - llama_kv_self_defrag(ctx); - ret = ctx->decode(batch); - - if (ret == 1) { - LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens); - - return ret; - } - } - - if (ret != 0) { + const int ret = ctx->decode(batch); + if (ret != 0 && ret != 1) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } diff --git a/src/llama-context.h b/src/llama-context.h index c0ceacb10..3b880286b 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -18,6 +18,9 @@ struct llama_kv_cache; class llama_io_read_i; class llama_io_write_i; +class llama_memory_i; +class llama_memory_state_i; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -47,7 +50,9 @@ struct llama_context { llama_kv_cache * get_kv_self(); const llama_kv_cache * get_kv_self() const; - void kv_self_update(); + // return true of the KV cache was updated + // TODO: remove + bool kv_self_update(); enum llama_pooling_type pooling_type() const; @@ -88,6 +93,16 @@ struct llama_context { int32_t il_start, int32_t il_end); + // process a single ubatch with a specific graph type + // if memory_state is provided, it will be applied first to the context's memory + // ret contains the status of the graph computation + // returns nullptr only if ret != GGML_STATUS_SUCCESS + llm_graph_result_ptr process_ubatch( + const llama_ubatch & ubatch, + llm_graph_type gtype, + llama_memory_state_i * mstate, + ggml_status & ret); + int encode(llama_batch & inp_batch); int decode(llama_batch & inp_batch); @@ -180,16 +195,18 @@ public: ggml_cgraph * graph_init(); // returns the result of ggml_backend_sched_graph_compute_async execution - ggml_status graph_compute( - ggml_cgraph * gf, - bool batched); + ggml_status graph_compute(ggml_cgraph * gf, bool batched); + + // reserve a graph with a dummy ubatch of the specified size + ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate); private: llm_graph_result_ptr graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype); + ggml_context * ctx, + ggml_cgraph * gf, + const 
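
The `llama_decode` wrapper above no longer defrags and retries on its own; the retry now lives inside `llama_context::decode`, and a return value of `1` still means "no KV cache slot for this batch". A hedged caller-side sketch of handling that return code follows; the "drop the oldest half of sequence 0" recovery is purely illustrative, and context/token setup is omitted.

```cpp
// Sketch: caller-side handling of llama_decode return codes after this change.
// 0 = ok, 1 = no KV slot for the batch, negative = error.
#include "llama.h"

#include <cstdio>
#include <vector>

static bool decode_with_fallback(llama_context * ctx, std::vector<llama_token> & tokens) {
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    const int ret = llama_decode(ctx, batch);
    if (ret == 0) {
        return true;
    }
    if (ret == 1) {
        // no KV cache slot: free space (illustrative: drop the oldest half of
        // sequence 0) or split the batch into smaller chunks and try again
        const llama_pos p_max = llama_kv_self_seq_pos_max(ctx, 0);
        llama_kv_self_seq_rm(ctx, 0, 0, p_max / 2);
        return llama_decode(ctx, batch) == 0;
    }
    std::fprintf(stderr, "llama_decode failed with %d\n", ret);
    return false;
}
```
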
llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_state_i * mstate); llm_graph_cb graph_get_cb() const; diff --git a/src/llama-cparams.cpp b/src/llama-cparams.cpp index 28369be36..f7b36590f 100644 --- a/src/llama-cparams.cpp +++ b/src/llama-cparams.cpp @@ -1 +1,5 @@ #include "llama-cparams.h" + +size_t llama_max_parallel_sequences(void) { + return LLAMA_MAX_PARALLEL_SEQUENCES; +} diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 246fa5777..2871031ef 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -4,6 +4,8 @@ #include +#define LLAMA_MAX_PARALLEL_SEQUENCES 64 + struct llama_cparams { uint32_t n_ctx; // context size used during inference uint32_t n_batch; diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 973b47ae0..bed706bb2 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token for (const auto & trigger_pattern : grammar.trigger_patterns) { if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) { grammar.awaiting_trigger = false; - // get from the first match to the end of the string - auto constrained_str = grammar.trigger_buffer.substr(match.position(1)); + // get from the first matched capturing group to the end of the string + size_t start = std::string::npos; + for (auto i = 1u; i < match.size(); i++) { + if (match.length(i) > 0) { + start = match.position(i); + break; + } + } + if (start == std::string::npos) { + start = match.position(0); + } + auto constrained_str = grammar.trigger_buffer.substr(start); // std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); grammar.trigger_buffer.clear(); llama_grammar_accept_str(grammar, constrained_str); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 13e36d161..727e119e3 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -3,7 +3,10 @@ #include "llama-impl.h" #include "llama-batch.h" #include "llama-cparams.h" -#include "llama-kv-cache.h" + +#include "llama-kv-cache-unified.h" +#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache-recurrent.h" #include #include @@ -83,7 +86,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) { if (pos_bucket) { - kv_self->set_input_pos_bucket(pos_bucket, ubatch); + kv_state->set_input_pos_bucket(pos_bucket, ubatch); } } @@ -234,7 +237,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); - const int64_t n_kv = kv_self->n; + const int64_t n_kv = kv_state->get_n_kv(); if (s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); @@ -242,7 +245,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_kv; ++i) { - data[i] = kv_self->s_copy(i); + data[i] = kv_state->s_copy(i); } } } @@ -250,7 +253,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); - const int64_t n_kv = kv_self->n; + const int64_t n_kv = kv_state->get_n_kv(); if (s_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer)); @@ -258,7 +261,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { // clear unused states for 
(int i = 0; i < n_kv; ++i) { - data[i] = kv_self->s_mask(i); + data[i] = kv_state->s_mask(i); } } } @@ -362,17 +365,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } } void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } if (self_kq_mask_swa) { - kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } } @@ -448,14 +451,14 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : backend_cpu (params.backend_cpu), cvec (params.cvec), loras (params.loras), - memory (params.memory), + mstate (params.mstate), cross (params.cross), cb_func (params.cb), res (std::make_unique()) { } int64_t llm_graph_context::n_pos_per_embd() const { - return arch == LLM_ARCH_QWEN2VL ? 4 : 1; + return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1; } void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const { @@ -954,11 +957,11 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { } ggml_tensor * llm_graph_context::build_inp_s_copy() const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(kv_self); + auto inp = std::make_unique(kv_state); - const auto n_kv = kv_self->n; + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->s_copy; @@ -971,11 +974,11 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const { } ggml_tensor * llm_graph_context::build_inp_s_mask() const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(kv_self); + auto inp = std::make_unique(kv_state); - const auto n_kv = kv_self->n; + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->s_mask; @@ -1025,11 +1028,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const { } ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, kv_self); + auto inp = std::make_unique(hparams, kv_state); - const auto n_kv = kv_self->get_n(); + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->pos_bucket; @@ -1231,14 +1234,14 @@ ggml_tensor * llm_graph_context::build_attn( } llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, cparams, kv_self); + auto inp = std::make_unique(hparams, cparams, kv_state); { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); - const auto n_kv = kv_self->get_n(); + const auto n_kv = kv_state->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1268,25 
+1271,29 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); // store to KV cache { - ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv_self->get_k(ctx0, il); - ggml_tensor * v = kv_self->get_v(ctx0, il); + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); if (wo) { cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4) { + // GLM4 seems to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } } if (wo_b) { @@ -1297,12 +1304,12 @@ ggml_tensor * llm_graph_context::build_attn( } llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { - const llama_kv_cache_unified_iswa * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, cparams, kv_self); + auto inp = std::make_unique(hparams, cparams, kv_state); { - const auto n_kv = kv_self->get_kv_base()->get_n(); + const auto n_kv = kv_state->get_base()->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1314,7 +1321,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif { GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA"); - const auto n_kv = kv_self->get_kv_swa()->get_n(); + const auto n_kv = kv_state->get_swa()->get_n_kv(); inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); @@ -1344,33 +1351,29 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); + const auto * kv_state_iswa = static_cast(mstate); + const bool is_swa = hparams.is_swa(il); - const llama_kv_cache_unified_iswa * kv_self = static_cast(memory); - - const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base(); + const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base(); // store to KV cache { - ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = is_swa ? 
inp->get_kq_mask_swa() : inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv->get_k(ctx0, il); - ggml_tensor * v = kv->get_v(ctx0, il); + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); if (wo) { cur = build_lora_mm(wo, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators - ggml_mul_mat_set_prec(cur, GGML_PREC_F32); - } } if (wo_b) { @@ -1446,12 +1449,12 @@ ggml_tensor * llm_graph_context::build_copy_mask_state( ggml_tensor * state_mask, int32_t n_state, int32_t n_seqs) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - const auto n_kv = kv_self->n; - const auto kv_head = kv_self->head; + const auto n_kv = kv_state->get_n_kv(); + const auto kv_head = kv_state->get_head(); - ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size()); // copy states // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv @@ -1478,13 +1481,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto token_shift_count = hparams.token_shift_count; const int64_t n_seqs = ubatch.n_seqs; - ggml_tensor * token_shift_all = kv_self->k_l[il]; + ggml_tensor * token_shift_all = kv_state->get_k_l(il); ggml_tensor * token_shift = build_copy_mask_state( gf, token_shift_all, state_copy, state_mask, @@ -1499,19 +1502,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ggml_tensor * token_shift, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto token_shift_count = hparams.token_shift_count; const auto n_embd = hparams.n_embd; const int64_t n_seqs = ubatch.n_seqs; - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); return ggml_cpy( ctx0, ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0), - ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il])) + ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il))) ); } @@ -1562,20 +1565,25 @@ void llm_graph_context::build_pooling( ggml_tensor * inp_cls = build_inp_cls(); inp = ggml_get_rows(ctx0, inp, inp_cls); - // classification head - // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 - GGML_ASSERT(cls != nullptr); - GGML_ASSERT(cls_b != nullptr); + if (cls != nullptr && cls_b != nullptr) { + // classification head + // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b); + cur = ggml_tanh(ctx0, cur); - cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b); - cur = ggml_tanh(ctx0, cur); - - // some models don't have `cls_out`, for example: 
https://huggingface.co/jinaai/jina-reranker-v1-tiny-en - // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 - if (cls_out) { + // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 + if (cls_out) { + GGML_ASSERT(cls_out_b != nullptr); + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b); + } + } else if (cls_out) { + // Single layer classification head (direct projection) + // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476 GGML_ASSERT(cls_out_b != nullptr); - - cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b); + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b); + } else { + GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b"); } } break; default: diff --git a/src/llama-graph.h b/src/llama-graph.h index 2b85bb25b..d1c5dd1bf 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -17,10 +17,11 @@ struct ggml_tensor; struct llama_ubatch; struct llama_cparams; -class llama_memory_i; -class llama_kv_cache_unified; -class llama_kv_cache_unified_iswa; -class llama_kv_cache_recurrent; +class llama_memory_state_i; + +class llama_kv_cache_unified_state; +class llama_kv_cache_unified_iswa_state; +class llama_kv_cache_recurrent_state; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -133,7 +134,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { public: llm_graph_input_pos_bucket_kv( const llama_hparams & hparams, - const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {} + const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {} virtual ~llm_graph_input_pos_bucket_kv() = default; void set_input(const llama_ubatch * ubatch) override; @@ -141,7 +142,7 @@ public: ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch] const llama_hparams & hparams; - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_unified_state * kv_state; }; class llm_graph_input_out_ids : public llm_graph_input_i { @@ -188,26 +189,26 @@ public: class llm_graph_input_s_copy : public llm_graph_input_i { public: - llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} virtual ~llm_graph_input_s_copy() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_copy; // I32 [kv_size] - const llama_kv_cache_recurrent * kv_self; + const llama_kv_cache_recurrent_state * kv_state; }; class llm_graph_input_s_mask : public llm_graph_input_i { public: - llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} virtual ~llm_graph_input_s_mask() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_mask; // F32 [1, n_kv] - const llama_kv_cache_recurrent * kv_self; + const llama_kv_cache_recurrent_state * kv_state; }; class llm_graph_input_cross_embd : public llm_graph_input_i { @@ -247,10 +248,10 @@ public: llm_graph_input_attn_kv_unified( const llama_hparams & hparams, const 
llama_cparams & cparams, - const llama_kv_cache_unified * kv_self) : + const llama_kv_cache_unified_state * kv_state) : hparams(hparams), cparams(cparams), - kv_self(kv_self) { + kv_state(kv_state) { } ~llm_graph_input_attn_kv_unified() = default; @@ -264,7 +265,7 @@ public: const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_unified_state * kv_state; }; class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { @@ -272,10 +273,10 @@ public: llm_graph_input_attn_kv_unified_iswa( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_iswa * kv_self) : + const llama_kv_cache_unified_iswa_state * kv_state) : hparams(hparams), cparams(cparams), - kv_self(kv_self) { + kv_state(kv_state) { } ~llm_graph_input_attn_kv_unified_iswa() = default; @@ -292,7 +293,7 @@ public: const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified_iswa * kv_self; + const llama_kv_cache_unified_iswa_state * kv_state; }; class llm_graph_input_attn_cross : public llm_graph_input_i { @@ -383,10 +384,10 @@ struct llm_graph_params { ggml_backend_sched_t sched; ggml_backend_t backend_cpu; - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_i * memory; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_state_i * mstate; + const llama_cross * cross; int32_t n_outputs; @@ -435,10 +436,10 @@ struct llm_graph_context { ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_i * memory; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_state_i * mstate; + const llama_cross * cross; const llm_graph_cb & cb_func; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 4f84e56b3..1499eb08a 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -2,6 +2,22 @@ #include "ggml.h" +void llama_hparams::set_swa_pattern(uint32_t n_pattern) { + for (uint32_t il = 0; il < n_layer; ++il) { + swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + } +} + +bool llama_hparams::is_swa_any() const { + for (uint32_t il = 0; il < n_layer; ++il) { + if (swa_layers[il]) { + return true; + } + } + + return false; +} + uint32_t llama_hparams::n_head(uint32_t il) const { if (il < n_layer) { return n_head_arr[il]; @@ -72,7 +88,7 @@ uint32_t llama_hparams::n_embd_v_s() const { bool llama_hparams::is_swa(uint32_t il) const { if (il < n_layer) { - return n_swa_pattern == 0 || (il % n_swa_pattern < (n_swa_pattern - 1)); + return swa_layers[il]; } GGML_ABORT("fatal error"); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 5222eedcf..b2bcb8b01 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -102,20 +102,12 @@ struct llama_hparams { // Sliding Window Attention (SWA) llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; - - uint32_t n_swa = 0; // the size of the sliding window (0 - no SWA) - uint32_t n_swa_pattern = 1; // this value n means that every nth layer is dense (i.e. 
non-SWA) - // by default n == 1, all layers are dense - // note that if n_swa_pattern == 0, all layers are SWA - // example: n_swa_pattern = 3 - // il == 0: swa - // il == 1: swa - // il == 2: dense - // il == 3: swa - // il == 4: swa - // il == 5: dense - // il == 6: swa - // etc ... + // the size of the sliding window (0 - no SWA) + uint32_t n_swa = 0; + // if swa_layers[il] == true, then layer il is SWA + // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA) + // by default, all layers are dense + std::array swa_layers; // for State Space Models uint32_t ssm_d_conv = 0; @@ -139,6 +131,9 @@ struct llama_hparams { bool attn_soft_cap = false; bool use_kq_norm = true; + // for Classifiers + uint32_t n_cls_out = 1; + // llama4 uint32_t n_moe_layer_step = 0; uint32_t n_no_rope_layer_step = 4; @@ -153,6 +148,23 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + // this value n_pattern means that every nth layer is dense (i.e. non-SWA) + // note that if n_pattern == 0, all layers are SWA + // if n_pattern == 1, all layers are dense + // example: n_pattern = 3 + // il == 0: swa + // il == 1: swa + // il == 2: dense + // il == 3: swa + // il == 4: swa + // il == 5: dense + // il == 6: swa + // etc ... + void set_swa_pattern(uint32_t n_pattern); + + // return true if one of the layers is SWA + bool is_swa_any() const; + uint32_t n_head(uint32_t il = 0) const; uint32_t n_head_kv(uint32_t il = 0) const; diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp new file mode 100644 index 000000000..641eab2f3 --- /dev/null +++ b/src/llama-kv-cache-recurrent.cpp @@ -0,0 +1,1132 @@ +#include "llama-kv-cache-recurrent.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include +#include +#include +#include +#include + +// +// llama_kv_cache_recurrent +// + +llama_kv_cache_recurrent::llama_kv_cache_recurrent( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { + const int32_t n_layer = hparams.n_layer; + + LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", + __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); + + head = 0; + size = kv_size; + used = 0; + + cells.clear(); + cells.resize(kv_size); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + k_l.reserve(n_layer); + v_l.reserve(n_layer); + + for (int i = 0; i < n_layer; i++) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (offload) { + auto * dev = model.dev_layer(i); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + 
LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for kv cache"); + } + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + k_l.push_back(k); + v_l.push_back(v); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for kv cache"); + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + bufs.emplace_back(buf); + } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } +} + +void llama_kv_cache_recurrent::clear() { + for (int32_t i = 0; i < (int32_t) size; ++i) { + cells[i].pos = -1; + cells[i].seq_id.clear(); + cells[i].src = -1; + cells[i].tail = -1; + } + head = 0; + used = 0; + + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } +} + +bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = size; + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // models like Mamba or RWKV can't have a state partially erased + if (seq_id >= (int64_t) size) { + // could be fatal + return false; + } + if (0 <= seq_id) { + int32_t & tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + const kv_cell & cell = cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { + tail_id = -1; + } + } + } else { + // seq_id is negative, then the range should include everything or nothing + if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + } + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].pos >= p0 && cells[i].pos < p1) { + if (seq_id < 0) { + cells[i].seq_id.clear(); + } else if (cells[i].has_seq_id(seq_id)) { + cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cells[i].is_empty()) { + // keep count of the number of used cells + if (cells[i].pos >= 0) { + used--; + } + cells[i].pos = -1; + cells[i].src = -1; + if (new_head == size) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. 
+ if (new_head != size && new_head < head) { + head = new_head; + } + + return true; +} + +void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits<llama_pos>::max(); + } + + if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { + kv_cell & tail_src = cells[seq_id_src]; + kv_cell & tail_dst = cells[seq_id_dst]; + if (tail_dst.tail >= 0) { + // clear destination seq_id if it wasn't empty + kv_cell & cell_dst = cells[tail_dst.tail]; + + cell_dst.seq_id.erase(seq_id_dst); + tail_dst.tail = -1; + if (cell_dst.seq_id.empty()) { + cell_dst.pos = -1; + cell_dst.src = -1; + used -= 1; + } + } + if (tail_src.tail >= 0) { + kv_cell & cell_src = cells[tail_src.tail]; + + cell_src.seq_id.insert(seq_id_dst); + tail_dst.tail = tail_src.tail; + } + } +} + +void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = size; + + for (uint32_t i = 0; i < size; ++i) { + if ((llama_seq_id) i != seq_id) { + cells[i].tail = -1; + } + + if (!cells[i].has_seq_id(seq_id)) { + if (cells[i].pos >= 0) { + used--; + } + + cells[i].pos = -1; + cells[i].src = -1; + cells[i].seq_id.clear(); + + if (new_head == size){ + new_head = i; + } + } else { + cells[i].seq_id.clear(); + cells[i].seq_id.insert(seq_id); + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != size && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits<llama_pos>::max(); + } + + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be shifted + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos += shift; + } + } + } +} + +void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits<llama_pos>::max(); + } + + // If there is no range then return early to avoid looping over the cache.
+ if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be changed + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos /= d; + } + } + } +} + +llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const { + llama_pos result = std::numeric_limits::max(); + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::min(result, cells[i].pos); + } + } + + if (result == std::numeric_limits::max()) { + result = -1; + } + + return result; +} + +llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { + llama_pos result = -1; + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::max(result, cells[i].pos); + } + } + + return result; +} + +llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { + GGML_UNUSED(embd_pooled); + + auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + llama_ubatch ubatch; + + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = sbatch.split_seq(n_ubatch); + } else { + ubatch = sbatch.split_equal(n_ubatch); + } + + ubatches.push_back(ubatch); + } + + if (!prepare(ubatches)) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_recurrent::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); +} + +bool llama_kv_cache_recurrent::prepare(const std::vector & ubatches) { + // simply remember the full state because it is very small for this type of cache + // TODO: optimize + auto org_cells = cells; + auto org_used = used; + auto org_head = head; + + bool success = true; + + // TODO: here we have to verify that all ubatches can fit in the cells + // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells + // during the compute of each ubatch. 
to reproduce, uncomment the following loop and run: + // + // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8 + // + // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed + // + GGML_UNUSED(ubatches); + //for (const auto & ubatch : ubatches) { + // if (!find_slot(ubatch)) { + // success = false; + // break; + // } + //} + + // restore the original state + cells = std::move(org_cells); + used = org_used; + head = org_head; + + return success; +} + +bool llama_kv_cache_recurrent::update(llama_context & lctx) { + GGML_UNUSED(lctx); + // noop + return false; +} + +void llama_kv_cache_recurrent::defrag_sched(float thold) { + GGML_UNUSED(thold); + // noop +} + +bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*n_tokens) { + head = 0; + } + + // For recurrent state architectures (like Mamba or RWKV), + // each cache cell can store the state for a whole sequence. + // A slot should always be contiguous. + + // can only process batches with an equal number of new tokens in each sequence + GGML_ASSERT(ubatch.equal_seqs); + + int32_t min = size - 1; + int32_t max = 0; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = ubatch.n_seq_id[s]; + for (uint32_t j = 0; j < n_seq_id; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= size) { + // too big seq_id + // TODO: would it be possible to resize the cache instead?
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max); + return false; + } + if (j > 0) { + kv_cell & seq = cells[seq_id]; + if (seq.tail >= 0) { + kv_cell & cell = cells[seq.tail]; + // clear cells from seq_ids that become shared + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + used -= 1; + } + } + } + } + } + +#ifndef NDEBUG + { + std::vector tails_verif; + tails_verif.assign(size, -1); + for (uint32_t i = 0; i < size; ++i) { + kv_cell & cell = cells[i]; + for (llama_seq_id seq_id : cell.seq_id) { + if (tails_verif[seq_id] != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); + } + tails_verif[seq_id] = i; + } + } + for (uint32_t i = 0; i < size; ++i) { + if (tails_verif[i] != cells[i].tail) { + LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); + } + } + } +#endif + + // find next empty cell + uint32_t next_empty_cell = head; + + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + kv_cell & seq_meta = cells[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + kv_cell & cell = cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? + if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + kv_cell & empty_cell = cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + kv_cell & orig_cell = cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + next_empty_cell += 1; + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + } + } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } + } + + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + int32_t dst_id = s + min; + int32_t src_id = cells[ubatch.seq_id[s][0]].tail; + if (dst_id != src_id) { + kv_cell & dst_cell = cells[dst_id]; + kv_cell & src_cell = cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails (assuming they NEVER overlap) + for (const llama_seq_id seq_id : src_cell.seq_id) { + cells[seq_id].tail = src_id; + } + for (const llama_seq_id seq_id : dst_cell.seq_id) { + cells[seq_id].tail = dst_id; + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + int32_t cell_id = s + min; + kv_cell & cell = cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a 
value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + cells[seq_id].tail = cell_id; + } + } + + // allow getting the range of used cells, from head to head + n + head = min; + n = max - min + 1; + used = std::count_if(cells.begin(), cells.end(), + [](const kv_cell & cell){ return !cell.is_empty(); }); + + // sanity check + return n >= n_seqs; +} + +bool llama_kv_cache_recurrent::get_can_shift() const { + return false; +} + +int32_t llama_kv_cache_recurrent::s_copy(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! + kv_cell & cell = const_cast(cells[cell_id]); + + // prevent out-of-bound sources + if (cell.src < 0 || (uint32_t) cell.src >= size) { + cell.src = cell_id; + } + + int32_t res = cell.src; + + // TODO: do not mutate the KV cache + // ensure copy only happens once + if (cell.src != (int32_t) cell_id) { + cell.src = cell_id; + } + + return res; +} + +float llama_kv_cache_recurrent::s_mask(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! + kv_cell & cell = const_cast(cells[cell_id]); + + float res = (float) (cell.src >= 0); + + // only clear once + if (cell.src < 0) { + cell.src = cell_id; + } + + return res; +} + +size_t llama_kv_cache_recurrent::total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +size_t llama_kv_cache_recurrent::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & k : k_l) { + size_k_bytes += ggml_nbytes(k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_recurrent::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & v : v_l) { + size_v_bytes += ggml_nbytes(v); + } + + return size_v_bytes; +} + +void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = size; + for (uint32_t i = 0; i < size; ++i) { + const auto & cell = cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++cell_count; + if (cell_range_begin == size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = size; + } + } + } + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, 
llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_id : cell.seq_id) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } + } +} + +void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = 0; + const uint32_t n_layer = hparams.n_layer; + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)k_l[il]->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(k_l[il], range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(v_l[il], range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = size; + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(v_l[il]->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : 
cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(v_l[il], src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 0) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + batch.pos[i] = pos; + } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; + + if (!find_slot(batch)) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head + cell_count <= size); + GGML_ASSERT(cells[head].pos == batch.pos[0]); + GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); + GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(); + + for (uint32_t i = 0; i < cell_count; ++i) { + kv_cell & cell = cells[i]; + + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cell.pos = pos; + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + // TODO: llama_kv_cache_recurrent should have a notion of max sequences + //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + if (seq_id < 0) { + //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + return false; + } + + cell.seq_id.insert(seq_id); + + int32_t & tail = cells[seq_id].tail; + if (tail != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); + return false; + } + tail = i; + } + } + + head = 0; + used = cell_count; + } + + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = head + i; + // make sure the recurrent states will keep their restored state + cells[cell_id].src = cell_id; + } + + return true; +} + +bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != hparams.n_layer) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + return false; + } + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); + return false; + } + if (false != (bool) v_trans) { 
+ LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) k_l[il]->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(v_l[il]->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { 
+ const size_t dst_offset = (head + j * size) * v_size_el; + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + + return true; +} + +// +// llama_kv_cache_recurrent_state +// + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) { +} + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv, + llama_sbatch sbatch, + std::vector ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {} + +llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default; + +bool llama_kv_cache_recurrent_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_recurrent_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + kv->find_slot(ubatches[i_next]); + + return true; +} + +std::vector & llama_kv_cache_recurrent_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_recurrent_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_recurrent_state::get_n_kv() const { + return is_full ? kv->size : kv->n; +} + +uint32_t llama_kv_cache_recurrent_state::get_head() const { + return is_full ? 
0 : kv->head; +} + +uint32_t llama_kv_cache_recurrent_state::get_size() const { + return kv->size; +} + +ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const { + return kv->k_l[il]; +} + +ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const { + return kv->v_l[il]; +} + +int32_t llama_kv_cache_recurrent_state::s_copy(int i) const { + return kv->s_copy(i); +} + +float llama_kv_cache_recurrent_state::s_mask(int i) const { + return kv->s_mask(i); +} diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h new file mode 100644 index 000000000..a178ae85c --- /dev/null +++ b/src/llama-kv-cache-recurrent.h @@ -0,0 +1,191 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache.h" + +#include +#include + +// +// llama_kv_cache_recurrent +// + +// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i +// see the implementation of llama_kv_cache_unified_state_i for an example how to do it +class llama_kv_cache_recurrent : public llama_kv_cache { +public: + llama_kv_cache_recurrent( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max); + + ~llama_kv_cache_recurrent() = default; + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; + + llama_memory_state_ptr init_full() override; + + bool update(llama_context & lctx) override; + + void defrag_sched(float thold) override; + + bool prepare(const std::vector & ubatches); + + // find a contiguous slot of kv cells and emplace the ubatch there + bool find_slot(const llama_ubatch & ubatch); + + bool get_can_shift() const override; + + // TODO: temporary methods - they are not really const as they do const_cast<>, fix this + int32_t s_copy(int i) const; + float s_mask(int i) const; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) + uint32_t size = 0; // total number of cells, shared across all sequences + uint32_t used = 0; // used cells (i.e. 
at least one seq_id) + + // computed before each graph build + uint32_t n = 0; + + // TODO: optimize for recurrent state needs + struct kv_cell { + llama_pos pos = -1; + int32_t src = -1; // used to copy states + int32_t tail = -1; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } + + bool is_empty() const { + return seq_id.empty(); + } + + bool is_same_seq(const kv_cell & other) const { + return seq_id == other.seq_id; + } + }; + + std::vector cells; + + std::vector k_l; // per layer + std::vector v_l; + +private: + //const llama_model & model; + const llama_hparams & hparams; + + const uint32_t n_seq_max = 1; + + std::vector ctxs; + std::vector bufs; + + size_t total_size() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t cell_count); +}; + +class llama_kv_cache_recurrent_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_recurrent_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv); + + // used to create a state from a batch + llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv, + llama_sbatch sbatch, + std::vector ubatches); + + virtual ~llama_kv_cache_recurrent_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_recurrent_state specific API + // + + uint32_t get_n_kv() const; + uint32_t get_head() const; + uint32_t get_size() const; + + ggml_tensor * get_k_l(int32_t il) const; + ggml_tensor * get_v_l(int32_t il) const; + + int32_t s_copy(int i) const; + float s_mask(int i) const; + +private: + const llama_memory_status status; + + llama_kv_cache_recurrent * kv; + + llama_sbatch sbatch; + + size_t i_next = 0; + + std::vector ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // TODO: extract all the state like `head` and `n` here + // + + const bool is_full = false; +}; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp new file mode 100644 index 000000000..0eb045634 --- /dev/null +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -0,0 +1,249 @@ +#include "llama-kv-cache-unified-iswa.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include +#include + +// +// llama_kv_cache_unified_iswa +// + +llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_ubatch, + uint32_t n_pad) : hparams(model.hparams) { + llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; + llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; + + const uint32_t size_base = kv_size; + + 
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); + + // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size + if (swa_full) { + LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + + size_swa = size_base; + } + + LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); + + kv_base = std::make_unique( + model, std::move(filter_base), type_k, type_v, + v_trans, offload, size_base, n_seq_max, n_pad, + 0, LLAMA_SWA_TYPE_NONE); + + LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); + + kv_swa = std::make_unique( + model, std::move(filter_swa), type_k, type_v, + v_trans, offload, size_swa, n_seq_max, n_pad, + hparams.n_swa, hparams.swa_type); +} + +void llama_kv_cache_unified_iswa::clear() { + kv_base->clear(); + kv_swa ->clear(); +} + +bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + bool res = true; + + res = res & kv_base->seq_rm(seq_id, p0, p1); + res = res & kv_swa ->seq_rm(seq_id, p0, p1); + + return res; +} + +void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { + kv_base->seq_keep(seq_id); + kv_swa ->seq_keep(seq_id); +} + +void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_base->seq_add(seq_id, p0, p1, shift); + kv_swa ->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_base->seq_div(seq_id, p0, p1, d); + kv_swa ->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const { + // the base cache is a superset of the SWA cache, so we can just check the SWA cache + return kv_swa->seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { + return kv_swa->seq_pos_max(seq_id); +} + +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { + GGML_UNUSED(embd_pooled); + + // TODO: if we fail with split_simple, we should attempt different splitting strategies + // but to do that properly, we first have to refactor the batches to be more flexible + + auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + auto ubatch = sbatch.split_simple(n_ubatch); + + ubatches.push_back(ubatch); + } + + auto heads_base = kv_base->prepare(ubatches); + if (heads_base.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + auto heads_swa = kv_swa->prepare(ubatches); + if (heads_swa.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + assert(heads_base.size() == heads_swa.size()); + + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); +} + +bool llama_kv_cache_unified_iswa::update(llama_context & 
lctx) { + bool res = false; + + res = res | kv_base->update(lctx); + res = res | kv_swa ->update(lctx); + + return res; +} + +void llama_kv_cache_unified_iswa::defrag_sched(float thold) { + kv_base->defrag_sched(thold); + kv_swa ->defrag_sched(thold); +} + +bool llama_kv_cache_unified_iswa::get_can_shift() const { + return kv_base->get_size() == kv_swa->get_size(); +} + +void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + kv_base->state_write(io, seq_id); + kv_swa ->state_write(io, seq_id); +} + +void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + kv_base->state_read(io, seq_id); + kv_swa ->state_read(io, seq_id); +} + +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const { + return kv_base.get(); +} + +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { + return kv_swa.get(); +} + +// +// llama_kv_cache_unified_iswa_state +// + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv) : status(status) { + state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base())); + state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ())); +} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches) + : status(status), + sbatch(std::move(sbatch)), + ubatches(std::move(ubatches)) { + // note: here we copy the ubatches. not sure if this is ideal + state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches)); + state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches)); + } + +llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; + +bool llama_kv_cache_unified_iswa_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + state_base->next(); + state_swa ->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_unified_iswa_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + bool res = true; + + res = res & state_base->apply(); + res = res & state_swa ->apply(); + + return res; +} + +std::vector & llama_kv_cache_unified_iswa_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} + +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return state_base.get(); +} + +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return state_swa.get(); +} diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h new file mode 100644 index 000000000..8b067da03 --- /dev/null +++ b/src/llama-kv-cache-unified-iswa.h @@ -0,0 +1,136 @@ +#pragma once + +#include 
"llama-kv-cache-unified.h" + +#include + +// +// llama_kv_cache_unified_iswa +// + +// utilizes two instances of llama_kv_cache_unified +// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers + +class llama_kv_cache_unified_iswa : public llama_kv_cache { +public: + llama_kv_cache_unified_iswa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_ubatch, + uint32_t n_pad); + + ~llama_kv_cache_unified_iswa() = default; + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; + + llama_memory_state_ptr init_full() override; + + bool update(llama_context & lctx) override; + + void defrag_sched(float thold) override; + + bool get_can_shift() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // + // llama_kv_cache_unified_iswa specific API + // + + llama_kv_cache_unified * get_base() const; + llama_kv_cache_unified * get_swa () const; + +private: + const llama_hparams & hparams; + + std::unique_ptr kv_base; + std::unique_ptr kv_swa; +}; + +class llama_kv_cache_unified_iswa_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_unified_iswa_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv); + + // used to create a state from a batch + llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches); + + virtual ~llama_kv_cache_unified_iswa_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_unified_iswa_state specific API + // + + const llama_kv_cache_unified_state * get_base() const; + const llama_kv_cache_unified_state * get_swa() const; + +private: + const llama_memory_status status; + + //llama_kv_cache_unified_iswa * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + std::unique_ptr state_base; + std::unique_ptr state_swa; +}; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp new file mode 100644 index 000000000..a81715476 --- /dev/null +++ b/src/llama-kv-cache-unified.cpp @@ -0,0 +1,1717 @@ +#include "llama-kv-cache-unified.h" + +#include "llama-impl.h" +#include "llama-model.h" 
+#include "llama-context.h" + +#include +#include +#include +#include +#include +#include + +// +// llama_kv_cache_unified +// + +llama_kv_cache_unified::llama_kv_cache_unified( + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type) : + model(model), hparams(model.hparams), v_trans(v_trans), + n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + + GGML_ASSERT(kv_size % n_pad == 0); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + head = 0; + + cells.resize(kv_size); + + for (uint32_t il = 0; il < hparams.n_layer; il++) { + if (filter && !filter(il)) { + LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); + continue; + } + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (offload) { + auto * dev = model.dev_layer(il); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for kv cache"); + } + + ggml_tensor * k; + ggml_tensor * v; + + k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); + v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); + + ggml_format_name(k, "cache_k_l%d", il); + ggml_format_name(v, "cache_v_l%d", il); + + map_layer_ids[il] = layers.size(); + layers.push_back({ il, k, v }); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for kv cache"); + } + + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + + ggml_backend_buffer_clear(buf, 0); + bufs.emplace_back(buf); + } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } +} + +void llama_kv_cache_unified::clear() { + cells.reset(); + + head = 0; + + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } +} + +bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos 
p1) { + uint32_t new_head = cells.size(); + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cells.size() && new_head < head) { + head = new_head; + } + + return true; +} + +void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id_src)) { + cells.seq_add(i, seq_id_dst); + } + } +} + +void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.seq_keep(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cells.size() && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { + return; + } + + uint32_t new_head = cells.size(); + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over all cells. + if (p0 == p1) { + return; + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id)) { + if (cells.pos_add(i, shift)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + head = new_head != cells.size() ? new_head : 0; +} + +void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the cache. 
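+    // note: positions are divided in place via cells.pos_div(); as with seq_add(), only cells that
+    //       belong to seq_id and whose position falls inside the requested range are touched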
+    if (p0 == p1) {
+        return;
+    }
+
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
+
+        if (cells.seq_has(i, seq_id)) {
+            cells.pos_div(i, d);
+        }
+    }
+}
+
+llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
+    return cells.seq_pos_min(seq_id);
+}
+
+llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
+    return cells.seq_pos_max(seq_id);
+}
+
+llama_memory_state_ptr llama_kv_cache_unified::init_batch(
+        const llama_batch & batch,
+        uint32_t n_ubatch,
+        bool embd_pooled,
+        bool logits_all) {
+    GGML_UNUSED(embd_pooled);
+
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+
+    std::vector<llama_ubatch> ubatches;
+    while (sbatch.n_tokens > 0) {
+        ubatches.push_back(sbatch.split_simple(n_ubatch));
+    }
+
+    auto heads = prepare(ubatches);
+    if (heads.empty()) {
+        return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+            this, std::move(sbatch), std::move(heads), std::move(ubatches));
+}
+
+llama_memory_state_ptr llama_kv_cache_unified::init_full() {
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+}
+
+std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    std::vector<uint32_t> res;
+
+    struct state {
+        uint32_t head_old; // old position of the head, before placing the ubatch
+        uint32_t head_new; // new position of the head, after placing the ubatch
+
+        llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
+    };
+
+    // remember the old state of the cells so we can restore it in the end
+    std::vector<state> states;
+
+    bool success = true;
+
+    for (const auto & ubatch : ubatches) {
+        // only find a suitable slot for the ubatch. don't modify the cells yet
+        const int32_t head_new = find_slot(ubatch);
+        if (head_new < 0) {
+            success = false;
+            break;
+        }
+
+        // remember the position that we found
+        res.push_back(head_new);
+
+        // store the old state of the cells in the recovery stack
+        states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+
+        // now emplace the ubatch
+        apply_ubatch(head_new, ubatch);
+    }
+
+    // iterate backwards and restore the cells to their original state
+    for (auto it = states.rbegin(); it != states.rend(); ++it) {
+        cells.set(it->head_new, it->cells);
+        head = it->head_old;
+    }
+
+    if (!success) {
+        return {};
+    }
+
+    return res;
+}
+
+bool llama_kv_cache_unified::update(llama_context & lctx) {
+    bool updated = false;
+
+    auto * sched = lctx.get_sched();
+
+    if (cells.get_has_shift()) {
+        if (!get_can_shift()) {
+            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
+        }
+
+        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
+
+        // apply K-shift if needed
+        if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            ggml_backend_sched_reset(sched);
+
+            auto * gf = lctx.graph_init();
+
+            auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            if (!res) {
+                LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
+                return updated;
+            }
+
+            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
+                return updated;
+            }
+
+            res->set_inputs(nullptr);
+
+            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+                LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
+                return updated;
+            }
+
+            updated = true;
+        }
+
+        cells.reset_shift();
+    }
+
+    if (do_defrag) {
+        LLAMA_LOG_DEBUG("%s: defragmenting KV 
cache\n", __func__); + + if (defrag_prepare(lctx.graph_max_nodes())) { + ggml_backend_sched_reset(sched); + + auto * gf = lctx.graph_init(); + + auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); + return updated; + } + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); + return updated; + } + + res->set_inputs(nullptr); + + if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); + return updated; + } + + updated = true; + } + + do_defrag = false; + } + + return updated; +} + +void llama_kv_cache_unified::defrag_sched(float thold) { + const auto n_kv = cells.used_max_p1(); + + // - do not defrag small contexts (i.e. < 2048 tokens) + // - count the padding towards the number of used tokens + const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f; + + // queue defragmentation for next llama_kv_cache_update + if (fragmentation > thold) { + LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); + + do_defrag = true; + } +} + +int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { + const uint32_t n_tokens = ubatch.n_tokens; + + uint32_t head_cur = this->head; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head_cur > cells.get_used() + 2*ubatch.n_tokens) { + head_cur = 0; + } + + // otherwise, one cell per token. + + if (n_tokens > cells.size()) { + LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); + return -1; + } + +//#define FIND_SLOT_DEBUG 1 +#if FIND_SLOT_DEBUG + LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa); + + // for debugging + { + std::string ss; + if (n_swa > 0) { + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.is_empty(i)) { + ss += '.'; + } else { + ss += std::to_string(cells.seq_get(i)); + } + if (i%256 == 255) { + ss += '\n'; + } + } + } + LLAMA_LOG_WARN("\n%s\n", ss.c_str()); + } + + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (cells.seq_pos_min(s) < 0) { + continue; + } + + LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s)); + } +#endif + + uint32_t n_tested = 0; + + while (true) { + if (head_cur + n_tokens > cells.size()) { + n_tested += cells.size() - head_cur; + head_cur = 0; + continue; + } + + // keep track of what the minimum sequence positions would be if we accept the ubatch + llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES]; + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + seq_pos_min[s] = cells.seq_pos_min(s); + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + const llama_pos pos = ubatch.pos[i]; + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + + // can we use this cell? 
either: + // - the cell is empty + // - the cell is occupied only by one sequence: + // - mask causally, if the sequence is the same as the one we are inserting + // - mask SWA, using current max pos for that sequence in the cache + // always insert in the cell with minimum pos + bool can_use = cells.is_empty(head_cur + i); + + if (!can_use && cells.seq_count(head_cur + i) == 1) { + const llama_pos pos_cell = cells.pos_get(head_cur + i); + + // causal mask + if (cells.seq_has(head_cur + i, seq_id)) { + can_use = pos_cell >= pos; + } + + if (!can_use) { + const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); + + // SWA mask + // note: we insert only in the cell with minimum pos in order to preserve the invariant that + // all positions between [pos_min, pos_max] for each sequence will be present in the cache + // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 + if (pos_cell == seq_pos_min[seq_id_cell] && + is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + seq_pos_min[seq_id_cell]++; + can_use = true; + } + } + } + + if (!can_use) { + found = false; + head_cur += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= cells.size()) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return -1; + } + } + + return head_cur; +} + +void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) { + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + if (!cells.is_empty(head_cur + i)) { + cells.rm(head_cur + i); + } + + cells.pos_set(head_cur + i, ubatch.pos[i]); + + for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) { + cells.seq_add(head_cur + i, ubatch.seq_id[i][j]); + } + } + + // move the head at the end of the slot + head = head_cur + ubatch.n_tokens; +} + +bool llama_kv_cache_unified::get_can_shift() const { + return true; +} + +uint32_t llama_kv_cache_unified::get_size() const { + return cells.size(); +} + +uint32_t llama_kv_cache_unified::get_n_kv() const { + return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); +} + +ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * k = layers[ikv].k; + + return ggml_view_3d(ctx, k, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, + ggml_row_size(k->type, hparams.n_embd_head_k), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), + 0); +} + +ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * v = layers[ikv].v; + + if (!v_trans) { + // note: v->nb[1] <= v->nb[2] + return ggml_view_3d(ctx, v, + hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] + 0); + } + + // note: v->nb[1] > v->nb[2] + return ggml_view_3d(ctx, v, + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, + ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, v->ne[1]), // v->nb[2] + 0); +} + +ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * k = layers[ikv].k; + + const int64_t n_tokens = k_cur->ne[2]; + + ggml_tensor * k_view = ggml_view_1d(ctx, k, + n_tokens*hparams.n_embd_k_gqa(il), + ggml_row_size(k->type, 
hparams.n_embd_k_gqa(il))*head_cur); + + return ggml_cpy(ctx, k_cur, k_view); +} + +ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * v = layers[ikv].v; + + const int64_t n_tokens = v_cur->ne[2]; + + v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); + + ggml_tensor * v_view = nullptr; + + if (!v_trans) { + v_view = ggml_view_1d(ctx, v, + n_tokens*hparams.n_embd_v_gqa(il), + ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur); + } else { + // note: the V cache is transposed when not using flash attention + v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il), + (v->ne[1])*ggml_element_size(v), + (head_cur)*ggml_element_size(v)); + + v_cur = ggml_transpose(ctx, v_cur); + } + + return ggml_cpy(ctx, v_cur, v_view); +} + +void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + float * data = (float *) dst->data; + + const auto n_kv = dst->ne[0]; + + // Use only the previous KV cells of the correct sequence for each token of the ubatch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. + // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: + // Causal mask: + // xxx------- + // xxxx------ + // xxxxx----- + // Non-causal mask: + // xxxxx----- + // xxxxx----- + // xxxxx----- + // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 + for (int h = 0; h < 1; ++h) { + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j]; + + for (uint32_t i = 0; i < n_kv; ++i) { + float f = 0.0f; + + bool masked = false; + + if (cells.is_empty(i)) { + masked = true; + } else { + const llama_pos p0 = cells.pos_get(i); + + // mask the token if not the same sequence + masked = masked || (!cells.seq_has(i, seq_id)); + + // mask future tokens + masked = masked || (causal_attn && p0 > p1); + + // apply SWA if any + masked = masked || (is_masked_swa(p0, p1)); + + if (!masked && hparams.use_alibi) { + f = -std::abs(p0 - p1); + } + } + + if (masked) { + f = -INFINITY; + } + + data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + } + } + } + + // mask padded tokens + if (data) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (uint32_t j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + } +} + +void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + int32_t * data = (int32_t *) dst->data; + + for (uint32_t i = 0; i < cells.size(); ++i) { + data[i] = cells.is_empty(i) ? 
0 : cells.get_shift(i); + } +} + +void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { + const int64_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + + int32_t * data = (int32_t *) dst->data; + + const int32_t n_kv = dst->ne[0]; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_kv; ++i) { + // the position when the cells is empty is irrelevant - it will be masked out later in the attention + const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i); + + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false); + } + } + } +} + +size_t llama_kv_cache_unified::total_size() const { + size_t size = 0; + + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +size_t llama_kv_cache_unified::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & layer : layers) { + size_k_bytes += ggml_nbytes(layer.k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_unified::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & layer : layers) { + size_v_bytes += ggml_nbytes(layer.v); + } + + return size_v_bytes; +} + +ggml_tensor * llama_kv_cache_unified::build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const { + const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; + + const auto & yarn_ext_factor = cparams.yarn_ext_factor; + const auto & yarn_beta_fast = cparams.yarn_beta_fast; + const auto & yarn_beta_slow = cparams.yarn_beta_slow; + + const auto & n_rot = hparams.n_rot; + const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE + // @ngxson : this is a workaround + // for M-RoPE, we want to rotate the whole vector when doing KV shift + // a normal RoPE should work, we just need to use the correct ordering + // ref: https://github.com/ggml-org/llama.cpp/pull/13870 + ? LLAMA_ROPE_TYPE_NEOX + : hparams.rope_type; + + // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 + ? 
1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) + : cparams.yarn_attn_factor; + + ggml_tensor * tmp; + + if (ggml_is_quantized(cur->type)) { + // dequantize to f32 -> RoPE -> quantize back + tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); + + tmp = ggml_rope_ext(ctx, tmp, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + + tmp = ggml_cpy(ctx, tmp, cur); + } else { + // we rotate only the first n_rot dimensions + tmp = ggml_rope_ext_inplace(ctx, cur, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + } + + return tmp; +} + +class llm_graph_input_k_shift : public llm_graph_input_i { +public: + llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + virtual ~llm_graph_input_k_shift() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * k_shift; // I32 [kv_size] + + const llama_kv_cache_unified * kv_self; +}; + +void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (k_shift) { + kv_self->set_input_k_shift(k_shift); + } +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { + auto res = std::make_unique(); + + const auto & n_embd_head_k = hparams.n_embd_head_k; + //const auto & n_embd_head_v = hparams.n_embd_head_v; + + //GGML_ASSERT(kv_self->size == n_ctx); + + auto inp = std::make_unique(this); + + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx); + ggml_set_input(inp->k_shift); + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + ggml_tensor * k = + ggml_view_3d(ctx, layer.k, + n_embd_head_k, n_head_kv, cells.size(), + ggml_row_size(layer.k->type, n_embd_head_k), + ggml_row_size(layer.k->type, n_embd_k_gqa), + 0); + + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); + + ggml_build_forward_expand(gf, cur); + } + + res->add_input(std::move(inp)); + + return res; +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { + auto res = std::make_unique(); + + const auto & ids = defrag_info.ids; + +#if 0 + // CPU defrag + // + // TODO: optimizations are possible: + // - multiple threads + // - avoid copying to the host memory when already there + // + // likely not worth the effort, as we have ggml_graph based defrag + // + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + + const uint32_t kv_size = size; + + std::vector buf_k; + std::vector buf_v; + + for (uint32_t il = 0; il < n_layer; ++il) { + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); + + const size_t v_size_el = ggml_type_size(v_l[il]->type); + const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); + + buf_k.resize(k_size); + buf_v.resize(v_size); + + 
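+        // reference CPU path (disabled): pull the whole per-layer K and V tensors to host memory,
+        //     compact them according to the ids[] mapping computed by defrag_prepare(), then write
+        //     them back with ggml_backend_tensor_set() below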
ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); + + // batch move [i, i+nm) to [id, id+nm) + // note: cells can move only to a lower index + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == n_kv) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < n_kv && ids[i + nm] == id + nm) { + nm++; + } + + // move keys + { + const int64_t os = i*k_size_row; + const int64_t od = id*k_size_row; + + memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); + } + + // move values (note: they are transposed) + { + const int64_t os = i; + const int64_t od = id; + + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); + } + } + + i += nm - 1; + } + + ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); + } +#else + for (uint32_t i = 0; i < ids.size(); ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == ids.size()) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < ids.size() && ids[i + nm] == id + nm) { + nm++; + } + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + + ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k, + n_embd_k_gqa, nm, + ggml_row_size(layer.k->type, n_embd_k_gqa), + ggml_row_size(layer.k->type, n_embd_k_gqa*i)); + + ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k, + n_embd_k_gqa, nm, + ggml_row_size(layer.k->type, n_embd_k_gqa), + ggml_row_size(layer.k->type, n_embd_k_gqa*id)); + + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; + + if (cparams.flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx, layer.v, + n_embd_v_gqa, nm, + ggml_row_size(layer.v->type, n_embd_v_gqa), + ggml_row_size(layer.v->type, n_embd_v_gqa*i)); + + view_v_dst = ggml_view_2d(ctx, layer.v, + n_embd_v_gqa, nm, + ggml_row_size(layer.v->type, n_embd_v_gqa), + ggml_row_size(layer.v->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx, layer.v, + nm, n_embd_v_gqa, + ggml_row_size(layer.v->type, cells.size()), + ggml_row_size(layer.v->type, i)); + + view_v_dst = ggml_view_2d(ctx, layer.v, + nm, n_embd_v_gqa, + ggml_row_size(layer.v->type, cells.size()), + ggml_row_size(layer.v->type, id)); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); + } + + i += nm - 1; + } + + //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); +#endif + + return res; +} + +bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { + const uint32_t n_layer = layers.size(); + + const uint32_t n_kv = cells.used_max_p1(); + const uint32_t n_used = cells.get_used(); + + assert(n_used <= n_kv); + + //const int64_t t_start = ggml_time_us(); + + // number of cells moved + uint32_t n_moves = 0; + + // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) + // - source view, destination view, copy operation + // - x2 for keys and values + //const uint32_t max_moves = max_nodes()/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); + + // determine which KV 
cells to move where + // + // cell i moves to ids[i] + // + // if ids[i] == i || ids[i] == n_kv, then cell i is not moved + // + auto & ids = defrag_info.ids; + + ids.clear(); + ids.resize(n_kv, n_kv); + + for (uint32_t i0 = 0; i0 < n_used; ++i0) { + if (!cells.is_empty(i0)) { + ids[i0] = i0; + + continue; + } + + // found a hole - fill it with data from the end of the cache + + uint32_t nh = 1; + + // determine the size of the hole + while (i0 + nh < n_used && cells.is_empty(i0 + nh)) { + nh++; + } + + uint32_t nf = 0; + uint32_t is = n_kv - 1; + + // starting from the end, find nh non-empty cells + for (; is > i0; --is) { + if (cells.is_empty(is) || ids[is] != n_kv) { + continue; + } + + // non-empty cell which is not yet moved + nf++; + + if (nf == nh) { + break; + } + } + + // this can only happen if `n_used` is not accurate, which would be a bug + GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh"); + + nf = 0; + + uint32_t i1 = is; + + // are we moving a continuous block of memory? + bool cont = false; + + // should we stop searching for the next move? + bool stop = false; + + // go back and move the nf cells to the hole + for (; i1 < n_kv; ++i1) { + if (cells.is_empty(i1) || ids[i1] != n_kv) { + if (n_moves == max_moves) { + stop = true; + break; + } + + cont = false; + continue; + } + + // this cell goes to (i0 + nf) + ids[i1] = i0 + nf; + + // move the cell meta data + cells.mv(i1, i0 + nf); + + head = n_used; + + if (!cont) { + n_moves++; + cont = true; + } + + nf++; + + if (nf == nh) { + break; + } + } + + if (stop || n_moves == max_moves) { + break; + } + + //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); + + i0 += nh - 1; + } + + if (n_moves == 0) { + return false; + } + + LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); + + LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); + + return true; +} + +bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + } + + return false; +} + +void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { + ++cell_count; + if (cell_range_begin == cells.size()) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != cells.size()) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = cells.size(); + } + } + } + + if (cell_range_begin != cells.size()) { + cell_ranges.emplace_back(cell_range_begin, cells.size()); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + 
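+    // layout: the cell count is written first, then the per-cell metadata (pos + seq ids) via
+    //     state_write_meta(), followed by the per-layer K/V tensor data via state_write_data()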
state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + std::vector seq_ids; + + for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) { + if (cur == seq_id || seq_id == -1) { + if (cells.seq_has(i, cur)) { + seq_ids.push_back(cur); + } + } + } + + const llama_pos pos = cells.pos_get(i); + const uint32_t n_seq_id = seq_ids.size(); + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + for (const auto & seq_id : seq_ids) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } +} + +void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = this->v_trans ? 1 : 0; + const uint32_t n_layer = layers.size(); + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)layer.k->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(layer.k, range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)layer.v->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(layer.v, range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = cells.size(); + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)layer.v->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = 
ggml_type_size(layer.v->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(layer.v, src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + batch.n_tokens = cell_count; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 1) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + // read the sequence id, but directly discard it - we will use dest_seq_id instead + { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + } + + batch.pos[i] = pos; + batch.n_seq_id[i] = n_seq_id; + batch.seq_id[i] = &dest_seq_id; + } + + const auto head_cur = find_slot(batch); + if (head_cur < 0) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + apply_ubatch(head_cur, batch); + + // keep the head at the old position because we will read the KV data into it in state_read_data() + head = head_cur; + + // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head_cur + cell_count <= cells.size()); + GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]); + GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]); + GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id)); + GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > cells.size()) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(); + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cells.pos_set(i, pos); + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); + return false; + } + + cells.seq_add(i, seq_id); + } + } + + head = 0; + } + + return true; +} + +bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != layers.size()) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); + return false; + } + + if (cell_count > cells.size()) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state 
(%u > %u)\n", __func__, cell_count, cells.size()); + return false; + } + + if (this->v_trans != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) layer.k->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + + if (!this->v_trans) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)layer.v->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)layer.v->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(layer.v->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, 
n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (head + j * cells.size()) * v_size_el; + ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + + return true; +} + +// +// llama_kv_cache_unified_state +// + +llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv) : status(status), kv(kv) { + n_kv = kv->get_size(); + head = 0; + } + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + std::vector heads, + std::vector ubatches) + : status(status), + kv(kv), + sbatch(std::move(sbatch)), + heads(std::move(heads)), + ubatches(std::move(ubatches)) { + } + +llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; + +bool llama_kv_cache_unified_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_unified_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + kv->apply_ubatch(heads[i_next], ubatches[i_next]); + + n_kv = kv->get_n_kv(); + head = heads[i_next]; + + return true; +} + +std::vector & llama_kv_cache_unified_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_unified_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_unified_state::get_n_kv() const { + return n_kv; +} + +ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const { + return kv->get_k(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const { + return kv->get_v(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { + return kv->cpy_k(ctx, k_cur, il, head); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { + return kv->cpy_v(ctx, v_cur, il, head); +} + +void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const { + kv->set_input_k_shift(dst); +} + +void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + kv->set_input_kq_mask(dst, ubatch, causal_attn); +} + +void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { + kv->set_input_pos_bucket(dst, ubatch); +} + +uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 
256u : 32u;
+}
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h
new file mode 100644
index 000000000..1f1d44b97
--- /dev/null
+++ b/src/llama-kv-cache-unified.h
@@ -0,0 +1,278 @@
+#pragma once
+
+#include "llama-batch.h"
+#include "llama-graph.h"
+#include "llama-kv-cache.h"
+#include "llama-kv-cells.h"
+
+#include <unordered_map>
+#include <vector>
+
+struct llama_cparams;
+struct llama_hparams;
+struct llama_model;
+struct llama_context;
+
+//
+// llama_kv_cache_unified
+//
+
+class llama_kv_cache_unified : public llama_kv_cache {
+public:
+    static uint32_t get_padding(const llama_cparams & cparams);
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
+    llama_kv_cache_unified(
+            const llama_model & model,
+            layer_filter_cb && filter,
+            ggml_type type_k,
+            ggml_type type_v,
+            bool v_trans,
+            bool offload,
+            uint32_t kv_size,
+            uint32_t n_seq_max,
+            uint32_t n_pad,
+            uint32_t n_swa,
+            llama_swa_type swa_type);
+
+    ~llama_kv_cache_unified() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    void clear() override;
+
+    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    //
+    // llama_kv_cache
+    //
+
+    llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) override;
+
+    llama_memory_state_ptr init_full() override;
+
+    bool update(llama_context & lctx) override;
+
+    void defrag_sched(float thold) override;
+
+    bool get_can_shift() const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
+
+    //
+    // llama_kv_cache_unified specific API
+    //
+
+    uint32_t get_size() const;
+
+    //
+    // graph_build API
+    //
+
+    uint32_t get_n_kv() const;
+
+    // get views of the current state of the cache
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
+
+    // store k_cur and v_cur in the cache based on the provided head location
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
+
+    //
+    // preparation API
+    //
+
+    // find places for the provided ubatches in the cache, returns the head locations
+    // return empty vector on failure
+    std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
+
+    // return the cell position where we can insert the ubatch
+    // return -1 on failure to find a contiguous slot of kv cells
+    int32_t find_slot(const llama_ubatch & ubatch) const;
+
+    // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
+    void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
+
+    //
+    // set_input API
+    //
+
+    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
+    void set_input_k_shift   (ggml_tensor * dst) const;
+    void 
set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+
+private:
+    const llama_model & model;
+    const llama_hparams & hparams;
+
+    struct kv_layer {
+        // layer index in the model
+        // note: can be different from the layer index in the KV cache
+        uint32_t il;
+
+        ggml_tensor * k;
+        ggml_tensor * v;
+    };
+
+    bool do_defrag = false;
+    bool v_trans   = true; // the value tensor is transposed
+
+    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
+    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
+    uint32_t head = 0;
+
+    const uint32_t n_seq_max = 1;
+
+    // required padding
+    const uint32_t n_pad = 1;
+
+    // SWA
+    const uint32_t n_swa = 0;
+
+    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+
+    std::vector<ggml_context_ptr>        ctxs;
+    std::vector<ggml_backend_buffer_ptr> bufs;
+
+    llama_kv_cells_unified cells;
+
+    std::vector<kv_layer> layers;
+
+    // model layer id -> KV cache layer id
+    std::unordered_map<int32_t, int32_t> map_layer_ids;
+
+    // defrag
+    struct {
+        std::vector<uint32_t> ids;
+    } defrag_info;
+
+    // return true if cells have been moved
+    bool defrag_prepare(int32_t n_max_nodes);
+
+    size_t total_size() const;
+
+    size_t size_k_bytes() const;
+    size_t size_v_bytes() const;
+
+    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
+
+    ggml_tensor * build_rope_shift(
+            const llama_cparams & cparams,
+            ggml_context * ctx,
+            ggml_tensor * cur,
+            ggml_tensor * shift,
+            ggml_tensor * factors,
+            float freq_base,
+            float freq_scale) const;
+
+    llm_graph_result_ptr build_graph_shift(
+            const llama_cparams & cparams,
+            ggml_context * ctx,
+            ggml_cgraph * gf) const;
+
+    llm_graph_result_ptr build_graph_defrag(
+            const llama_cparams & cparams,
+            ggml_context * ctx,
+            ggml_cgraph * gf) const;
+
+    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
+
+    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
+};
+
+class llama_kv_cache_unified_state : public llama_memory_state_i {
+public:
+    // used for errors
+    llama_kv_cache_unified_state(llama_memory_status status);
+
+    // used to create a full-cache state
+    llama_kv_cache_unified_state(
+            llama_memory_status status,
+            llama_kv_cache_unified * kv);
+
+    // used to create a state from a batch
+    llama_kv_cache_unified_state(
+            llama_memory_status status,
+            llama_kv_cache_unified * kv,
+            llama_sbatch sbatch,
+            std::vector<uint32_t> heads,
+            std::vector<llama_ubatch> ubatches);
+
+    virtual ~llama_kv_cache_unified_state();
+
+    //
+    // llama_memory_state_i
+    //
+
+    bool next()  override;
+    bool apply() override;
+
+    std::vector & out_ids() override;
+
+    llama_memory_status  get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_kv_cache_unified_state specific API
+    //
+
+    uint32_t get_n_kv() const;
+
+    // get views of the current state of the cache
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
+
+    // store k_cur and v_cur in the cache based on the provided head location
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+
+    void set_input_k_shift(ggml_tensor * dst) const;
+
+    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * 
ubatch, bool causal_attn) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + +private: + const llama_memory_status status; + + llama_kv_cache_unified * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector heads; + std::vector ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // + + // a heuristic, to avoid attending the full cache if it is not yet utilized + // as the cache gets filled, the benefit from this heuristic disappears + int32_t n_kv; + + // the beginning of the current slot in which the ubatch will be inserted + int32_t head; +}; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index a2624d715..aefd23e32 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1,2827 +1 @@ #include "llama-kv-cache.h" - -#include "llama-impl.h" -#include "llama-batch.h" -#include "llama-cparams.h" -#include "llama-model.h" -#include "llama-context.h" - -#include -#include -#include -#include -#include -#include - -// -// llama_kv_cache_unified -// - -uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { - // the FA kernels require padding to avoid extra runtime boundary checks - return cparams.flash_attn ? 256u : 32u; -} - -llama_kv_cache_unified::llama_kv_cache_unified( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type) : - model(model), hparams(model.hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { - - GGML_ASSERT(kv_size % n_pad == 0); - - // create a context for each buffer type - std::map ctx_map; - auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { - auto it = ctx_map.find(buft); - if (it == ctx_map.end()) { - ggml_init_params params = { - /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context * ctx = ggml_init(params); - if (!ctx) { - return nullptr; - } - - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); - - return ctx; - } - - return it->second; - }; - - head = 0; - size = kv_size; - used = 0; - - cells.resize(kv_size); - - for (uint32_t il = 0; il < hparams.n_layer; il++) { - if (filter && !filter(il)) { - LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); - continue; - } - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - const char * dev_name = "CPU"; - - ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); - - if (offload) { - auto * dev = model.dev_layer(il); - buft = ggml_backend_dev_buffer_type(dev); - - dev_name = ggml_backend_dev_name(dev); - } - - LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name); - - ggml_context * ctx = ctx_for_buft(buft); - if (!ctx) { - throw std::runtime_error("failed to create ggml context for kv cache"); - } - - ggml_tensor * k; - ggml_tensor * v; - - k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); - v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); - - ggml_format_name(k, "cache_k_l%d", il); - ggml_format_name(v, "cache_v_l%d", il); - - map_layer_ids[il] = layers.size(); - layers.push_back({ il, k, v }); - } - - // allocate tensors and initialize the 
buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); - if (!buf) { - throw std::runtime_error("failed to allocate buffer for kv cache"); - } - - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); - - ggml_backend_buffer_clear(buf, 0); - bufs.emplace_back(buf); - } - - { - const size_t memory_size_k = size_k_bytes(); - const size_t memory_size_v = size_v_bytes(); - - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); - } -} - -void llama_kv_cache_unified::clear() { - for (uint32_t i = 0; i < size; ++i) { - cells[i].pos = -1; - cells[i].seq_id.clear(); - } - - head = 0; - used = 0; - - for (auto & buf : bufs) { - ggml_backend_buffer_clear(buf.get(), 0); - } -} - -bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - uint32_t new_head = size; - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].pos >= p0 && cells[i].pos < p1) { - if (seq_id < 0) { - cells[i].seq_id.clear(); - } else if (cells[i].has_seq_id(seq_id)) { - cells[i].seq_id.erase(seq_id); - } else { - continue; - } - - if (cells[i].is_empty()) { - // keep count of the number of used cells - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; - - if (new_head == size) { - new_head = i; - } - } - } - } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != size && new_head < head) { - head = new_head; - } - - return true; -} - -void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (seq_id_src == seq_id_dst) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // otherwise, this is the KV of a Transformer-like model - head = 0; - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) { - cells[i].seq_id.insert(seq_id_dst); - } - } -} - -void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { - uint32_t new_head = size; - - for (uint32_t i = 0; i < size; ++i) { - if (!cells[i].has_seq_id(seq_id)) { - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; - cells[i].seq_id.clear(); - - if (new_head == size){ - new_head = i; - } - } else { - cells[i].seq_id.clear(); - cells[i].seq_id.insert(seq_id); - } - } - - // If we freed up a slot, set head to it so searching can start there. 
- if (new_head != size && new_head < head) { - head = new_head; - } -} - -void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - if (delta == 0) { - return; - } - - uint32_t new_head = size; - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over the - if (p0 == p1) { - return; - } - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { - has_shift = true; - - cells[i].pos += delta; - cells[i].delta += delta; - - if (cells[i].pos < 0) { - if (!cells[i].is_empty()) { - used--; - } - cells[i].pos = -1; - cells[i].seq_id.clear(); - if (new_head == size) { - new_head = i; - } - } - } - } - - // If we freed up a slot, set head to it so searching can start there. - // Otherwise we just start the next search from the beginning. - head = new_head != size ? new_head : 0; -} - -void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - if (d == 1) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over the cache. - if (p0 == p1) { - return; - } - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { - has_shift = true; - - { - llama_pos p_old = cells[i].pos; - cells[i].pos /= d; - cells[i].delta += cells[i].pos - p_old; - } - } - } -} - -llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { - llama_pos result = std::numeric_limits::max(); - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id)) { - result = std::min(result, cells[i].pos); - } - } - - if (result == std::numeric_limits::max()) { - result = -1; - } - - return result; -} - -llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { - llama_pos result = -1; - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id)) { - result = std::max(result, cells[i].pos); - } - } - - return result; -} - -void llama_kv_cache_unified::restore() { - for (const auto & [id, cell] : recovery.cells) { - // TODO: move to new `struct kv_cells` - const bool is_empty0 = cells[id].is_empty(); - const bool is_empty1 = cell.is_empty(); - - if (!is_empty0 && is_empty1) { - used--; - } else if (is_empty0 && !is_empty1) { - used++; - } - - cells[id] = cell; - } - - recovery.clear(); -} - -void llama_kv_cache_unified::commit() { - if (recovery.cells.empty()) { - LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n", - __func__, "https://github.com/ggml-org/llama.cpp/pull/13194"); - return; - } - - recovery.clear(); -} - -bool llama_kv_cache_unified::update(llama_context & lctx) { - bool need_reserve = false; - - auto * sched = lctx.get_sched(); - - if (has_shift) { - if (!get_can_shift()) { - GGML_ABORT("The current KV cache / model configuration does not support K-shift"); - } - - LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); - - // apply K-shift if needed - if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { - ggml_backend_sched_reset(sched); - - auto * gf = lctx.graph_init(); - - auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf); - - ggml_backend_sched_alloc_graph(sched, gf); - - res->set_inputs(nullptr); - - lctx.graph_compute(gf, false); - - need_reserve = true; - } - - { - 
has_shift = false; - - for (uint32_t i = 0; i < size; ++i) { - cells[i].delta = 0; - } - } - } - - if (do_defrag) { - LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - - if (defrag_prepare(lctx.graph_max_nodes())) { - ggml_backend_sched_reset(sched); - - auto * gf = lctx.graph_init(); - - auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf); - - ggml_backend_sched_alloc_graph(sched, gf); - - res->set_inputs(nullptr); - - lctx.graph_compute(gf, false); - - need_reserve = true; - } - - do_defrag = false; - } - - return need_reserve; -} - -void llama_kv_cache_unified::defrag_sched(float thold) { - // - do not defrag small contexts (i.e. < 2048 tokens) - // - count the padding towards the number of used tokens - const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f; - - // queue defragmentation for next llama_kv_cache_update - if (fragmentation > thold) { - LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); - - do_defrag = true; - } -} - -void llama_kv_cache_unified::set_full() { - n = size; - - // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not - // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views. - // we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so - // setting it to 0 is the simplest way to achieve that - // ref: https://github.com/ggml-org/llama.cpp/issues/13359 - head = 0; -} - -llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) { - return llama_sbatch(batch, hparams.n_embd, true, logits_all); -} - -llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - GGML_UNUSED(embd_pooled); - return sbatch.split_simple(n_ubatch); -} - -bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { - const uint32_t n_tokens = ubatch.n_tokens; - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (head > used + 2*ubatch.n_tokens) { - head = 0; - } - - // otherwise, one cell per token. 
- - if (n_tokens > size) { - LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size); - return false; - } - -//#define FIND_SLOT_DEBUG 1 -#if FIND_SLOT_DEBUG - LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); - - // for debugging - { - std::string ss; - if (n_swa > 0) { - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].pos == -1) { - ss += '.'; - } else { - ss += std::to_string(*cells[i].seq_id.begin()); - } - if (i%256 == 255) { - ss += '\n'; - } - } - } - LLAMA_LOG_WARN("\n%s\n", ss.c_str()); - } -#endif - - uint32_t n_tested = 0; - - while (true) { - if (head + n_tokens > size) { - n_tested += size - head; - head = 0; - continue; - } - - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - if (cells[head + i].pos >= 0) { - found = false; - head += i + 1; - n_tested += i + 1; - break; - } - } - - if (found) { - break; - } - - if (n_tested >= size) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; - } - } - - for (uint32_t i = 0; i < n_tokens; ++i) { - // remember the original state - if (recovery.cells.find(head + i) == recovery.cells.end()) { - recovery.cells[head + i] = cells[head + i]; - } - - cells[head + i].pos = ubatch.pos[i]; - - for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) { - cells[head + i].seq_id.insert(ubatch.seq_id[i][j]); - } - } - - used += n_tokens; - - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad))); - -#ifdef FIND_SLOT_DEBUG - LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); -#endif - - return true; -} - -bool llama_kv_cache_unified::get_can_shift() const { - return true; -} - -uint32_t llama_kv_cache_unified::get_n() const { - return n; -} - -uint32_t llama_kv_cache_unified::get_size() const { - return size; -} - -ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const { - const int32_t ikv = map_layer_ids.at(il); - - auto * k = layers[ikv].k; - - return ggml_view_3d(ctx, k, - hparams.n_embd_head_k, hparams.n_head_kv(il), n, - ggml_row_size(k->type, hparams.n_embd_head_k), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), - 0); -} - -ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const { - const int32_t ikv = map_layer_ids.at(il); - - auto * v = layers[ikv].v; - - if (!v_trans) { - // note: v->nb[1] <= v->nb[2] - return ggml_view_3d(ctx, v, - hparams.n_embd_head_v, hparams.n_head_kv(il), n, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] - 0); - } - - // note: v->nb[1] > v->nb[2] - return ggml_view_3d(ctx, v, - n, hparams.n_head_kv(il), hparams.n_embd_head_v, - ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, v->ne[1]), // v->nb[2] - 0); -} - -ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { - const int32_t ikv = map_layer_ids.at(il); - - auto * k = layers[ikv].k; - - const int64_t n_tokens = k_cur->ne[2]; - - ggml_tensor * k_view = ggml_view_1d(ctx, k, - n_tokens*hparams.n_embd_k_gqa(il), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head); - - return ggml_cpy(ctx, k_cur, k_view); -} - -ggml_tensor * 
llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { - const int32_t ikv = map_layer_ids.at(il); - - auto * v = layers[ikv].v; - - const int64_t n_tokens = v_cur->ne[2]; - - v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); - - ggml_tensor * v_view = nullptr; - - if (!v_trans) { - v_view = ggml_view_1d(ctx, v, - n_tokens*hparams.n_embd_v_gqa(il), - ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head); - } else { - // note: the V cache is transposed when not using flash attention - v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il), - (v->ne[1])*ggml_element_size(v), - ( head)*ggml_element_size(v)); - - v_cur = ggml_transpose(ctx, v_cur); - } - - return ggml_cpy(ctx, v_cur, v_view); -} - -void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) { - // no pruning is needed when the cache does not use SWA - GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache"); - - int n_attended = 0; - - for (uint32_t i = 0; i < size; ++i) { - const llama_pos p0 = cells[i].pos; - - if (p0 <= pmin && !is_masked_swa(p0, pmin)) { - n_attended++; - } - - if (is_masked_swa(p0, pmax)) { - if (seq_id < 0) { - cells[i].seq_id.clear(); - } else if (cells[i].has_seq_id(seq_id)) { - cells[i].seq_id.erase(seq_id); - } else { - continue; - } - - if (cells[i].is_empty()) { - // keep count of the number of used cells - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; - } - } - } - - if (n_attended < std::min(n_swa, pmin)) { - LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa); - } -} - -void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { - const int64_t n_tokens = ubatch->n_tokens; - const int64_t n_seq_tokens = ubatch->n_seq_tokens; - const int64_t n_seqs = ubatch->n_seqs; - - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - float * data = (float *) dst->data; - - const int64_t n_kv = n; - - // Use only the previous KV cells of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. 
- // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: - // Causal mask: - // xxx------- - // xxxx------ - // xxxxx----- - // Non-causal mask: - // xxxxx----- - // xxxxx----- - // xxxxx----- - // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 - for (int h = 0; h < 1; ++h) { - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[s][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j]; - - for (int i = 0; i < n_kv; ++i) { - const llama_pos p0 = cells[i].pos; - - bool masked = false; - - // mask the token if not the same sequence - masked = masked || (!cells[i].has_seq_id(seq_id)); - - // mask future tokens - masked = masked || (causal_attn && p0 > p1); - - // apply SWA if any - masked = masked || (is_masked_swa(p0, p1)); - - float f = 0.0f; - - if (masked) { - f = -INFINITY; - } else if (hparams.use_alibi) { - f = -std::abs(p0 - p1); - } - - data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; - } - } - } - - // mask padded tokens - if (data) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; - } - } - } - } -} - -void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - - int32_t * data = (int32_t *) dst->data; - - for (uint32_t i = 0; i < size; ++i) { - data[i] = cells[i].delta; - } -} - -void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { - const int64_t n_tokens = ubatch->n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing - - int32_t * data = (int32_t *) dst->data; - - const int64_t n_kv = n; - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_kv; ++i) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false); - } - } - } -} - -size_t llama_kv_cache_unified::total_size() const { - size_t size = 0; - - for (const auto & buf : bufs) { - size += ggml_backend_buffer_get_size(buf.get()); - } - - return size; -} - -size_t llama_kv_cache_unified::size_k_bytes() const { - size_t size_k_bytes = 0; - - for (const auto & layer : layers) { - size_k_bytes += ggml_nbytes(layer.k); - } - - return size_k_bytes; -} - -size_t llama_kv_cache_unified::size_v_bytes() const { - size_t size_v_bytes = 0; - - for (const auto & layer : layers) { - size_v_bytes += ggml_nbytes(layer.v); - } - - return size_v_bytes; -} - -ggml_tensor * llama_kv_cache_unified::build_rope_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const { - const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; - - const auto & yarn_ext_factor = cparams.yarn_ext_factor; - const auto & yarn_beta_fast = cparams.yarn_beta_fast; - const auto & yarn_beta_slow = cparams.yarn_beta_slow; - - const auto & n_rot = hparams.n_rot; - const auto & rope_type = hparams.rope_type; - - // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. - // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. - const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 
1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; - - ggml_tensor * tmp; - - if (ggml_is_quantized(cur->type)) { - // dequantize to f32 -> RoPE -> quantize back - tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); - - tmp = ggml_rope_ext(ctx, tmp, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - - tmp = ggml_cpy(ctx, tmp, cur); - } else { - // we rotate only the first n_rot dimensions - tmp = ggml_rope_ext_inplace(ctx, cur, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - } - - return tmp; -} - -class llm_graph_input_k_shift : public llm_graph_input_i { -public: - llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} - virtual ~llm_graph_input_k_shift() = default; - - void set_input(const llama_ubatch * ubatch) override; - - ggml_tensor * k_shift; // I32 [kv_size] - - const llama_kv_cache_unified * kv_self; -}; - -void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { - GGML_UNUSED(ubatch); - - if (k_shift) { - kv_self->set_input_k_shift(k_shift); - } -} - -llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - const auto & n_embd_head_k = hparams.n_embd_head_k; - //const auto & n_embd_head_v = hparams.n_embd_head_v; - - //GGML_ASSERT(kv_self->size == n_ctx); - - auto inp = std::make_unique(this); - - inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx); - ggml_set_input(inp->k_shift); - - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const int64_t n_head_kv = hparams.n_head_kv(il); - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - - const float freq_base_l = model.get_rope_freq_base (cparams, il); - const float freq_scale_l = model.get_rope_freq_scale(cparams, il); - - ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - - ggml_tensor * k = - ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, size, - ggml_row_size(layer.k->type, n_embd_head_k), - ggml_row_size(layer.k->type, n_embd_k_gqa), - 0); - - ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); - - ggml_build_forward_expand(gf, cur); - } - - res->add_input(std::move(inp)); - - return res; -} - -llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - const auto & ids = defrag_info.ids; - -#if 0 - // CPU defrag - // - // TODO: optimizations are possible: - // - multiple threads - // - avoid copying to the host memory when already there - // - // likely not worth the effort, as we have ggml_graph based defrag - // - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - - const uint32_t kv_size = size; - - std::vector buf_k; - std::vector buf_v; - - for (uint32_t il = 0; il < n_layer; ++il) { - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); - - const size_t v_size_el = ggml_type_size(v_l[il]->type); - const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); - - buf_k.resize(k_size); - buf_v.resize(v_size); - - 
ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); - - // batch move [i, i+nm) to [id, id+nm) - // note: cells can move only to a lower index - for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == n_kv) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < n_kv && ids[i + nm] == id + nm) { - nm++; - } - - // move keys - { - const int64_t os = i*k_size_row; - const int64_t od = id*k_size_row; - - memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); - } - - // move values (note: they are transposed) - { - const int64_t os = i; - const int64_t od = id; - - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); - } - } - - i += nm - 1; - } - - ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); - } -#else - for (uint32_t i = 0; i < ids.size(); ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == ids.size()) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < ids.size() && ids[i + nm] == id + nm) { - nm++; - } - - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - - ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k, - n_embd_k_gqa, nm, - ggml_row_size(layer.k->type, n_embd_k_gqa), - ggml_row_size(layer.k->type, n_embd_k_gqa*i)); - - ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k, - n_embd_k_gqa, nm, - ggml_row_size(layer.k->type, n_embd_k_gqa), - ggml_row_size(layer.k->type, n_embd_k_gqa*id)); - - ggml_tensor * view_v_src; - ggml_tensor * view_v_dst; - - if (cparams.flash_attn) { - // NOTE: the V cache is not transposed when using flash attention - view_v_src = ggml_view_2d(ctx, layer.v, - n_embd_v_gqa, nm, - ggml_row_size(layer.v->type, n_embd_v_gqa), - ggml_row_size(layer.v->type, n_embd_v_gqa*i)); - - view_v_dst = ggml_view_2d(ctx, layer.v, - n_embd_v_gqa, nm, - ggml_row_size(layer.v->type, n_embd_v_gqa), - ggml_row_size(layer.v->type, n_embd_v_gqa*id)); - } else { - view_v_src = ggml_view_2d(ctx, layer.v, - nm, n_embd_v_gqa, - ggml_row_size(layer.v->type, size), - ggml_row_size(layer.v->type, i)); - - view_v_dst = ggml_view_2d(ctx, layer.v, - nm, n_embd_v_gqa, - ggml_row_size(layer.v->type, size), - ggml_row_size(layer.v->type, id)); - } - - ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); - ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); - } - - i += nm - 1; - } - - //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); -#endif - - return res; -} - -bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { - const uint32_t n_layer = layers.size(); - - const uint32_t n_kv = cell_max(); - const uint32_t n_used = used; - - assert(n_used <= n_kv); - - //const int64_t t_start = ggml_time_us(); - - // number of cells moved - uint32_t n_moves = 0; - - // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) - // - source view, destination view, copy operation - // - x2 for keys and values - //const uint32_t max_moves = max_nodes()/(6*n_layer); - // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 - const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); - - // determine which KV cells to move where - // - // cell i 
moves to ids[i] - // - // if ids[i] == i || ids[i] == n_kv, then cell i is not moved - // - auto & ids = defrag_info.ids; - - ids.clear(); - ids.resize(n_kv, n_kv); - - for (uint32_t i0 = 0; i0 < n_used; ++i0) { - const auto & cell0 = cells[i0]; - - if (!cell0.is_empty()) { - ids[i0] = i0; - - continue; - } - - // found a hole - fill it with data from the end of the cache - - uint32_t nh = 1; - - // determine the size of the hole - while (i0 + nh < n_used && cells[i0 + nh].is_empty()) { - nh++; - } - - uint32_t nf = 0; - uint32_t is = n_kv - 1; - - // starting from the end, find nh non-empty cells - for (; is > i0; --is) { - const auto & cell1 = cells[is]; - - if (cell1.is_empty() || ids[is] != n_kv) { - continue; - } - - // non-empty cell which is not yet moved - nf++; - - if (nf == nh) { - break; - } - } - - // this can only happen if `n_used` is not accurate, which would be a bug - GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh"); - - nf = 0; - - uint32_t i1 = is; - - // are we moving a continuous block of memory? - bool cont = false; - - // should we stop searching for the next move? - bool stop = false; - - // go back and move the nf cells to the hole - for (; i1 < n_kv; ++i1) { - auto & cell1 = cells[i1]; - - if (cell1.is_empty() || ids[i1] != n_kv) { - if (n_moves == max_moves) { - stop = true; - break; - } - - cont = false; - continue; - } - - // this cell goes to (i0 + nf) - ids[i1] = i0 + nf; - - // move the cell meta data - cells[i0 + nf] = cell1; - - // clear the old cell and move the head there - cell1 = kv_cell(); - head = n_used; - - if (!cont) { - n_moves++; - cont = true; - } - - nf++; - - if (nf == nh) { - break; - } - } - - if (stop || n_moves == max_moves) { - break; - } - - //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); - - i0 += nh - 1; - } - - if (n_moves == 0) { - return false; - } - - LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); - - LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); - - return true; -} - -uint32_t llama_kv_cache_unified::cell_max() const { - for (uint32_t i = size; i > 0; --i) { - const kv_cell & cell = cells[i - 1]; - - if (cell.pos >= 0 && !cell.is_empty()) { - return i; - } - } - - return 0; -} - -bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { - if (p0 < 0) { - return true; - } - - switch (swa_type) { - case LLAMA_SWA_TYPE_NONE: - { - } break; - case LLAMA_SWA_TYPE_STANDARD: - { - if (p1 - p0 >= (int32_t) n_swa) { - return true; - } - } break; - case LLAMA_SWA_TYPE_CHUNKED: - { - const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; - - if (p0 < pos_chunk_start) { - return true; - } - } break; - } - - return false; -} - -void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - uint32_t cell_count = 0; - - // Count the number of cells with the specified seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = size; - for (uint32_t i = 0; i < size; ++i) { - const auto & cell = cells[i]; - if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { - ++cell_count; - if (cell_range_begin == size) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != size) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = size; - } - } - } - if (cell_range_begin != size) { - cell_ranges.emplace_back(cell_range_begin, size); - } - 
- // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; - } - GGML_ASSERT(cell_count == cell_count_check); - - io.write(&cell_count, sizeof(cell_count)); - - state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); -} - -void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); - - bool res = true; - res = res && state_read_meta(io, cell_count, seq_id); - res = res && state_read_data(io, cell_count); - - if (!res) { - if (seq_id == -1) { - clear(); - } else { - seq_rm(seq_id, -1, -1); - } - throw std::runtime_error("failed to restore kv cache"); - } -} - -void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { - for (const auto & range : cell_ranges) { - for (uint32_t i = range.first; i < range.second; ++i) { - const auto & cell = cells[i]; - const llama_pos pos = cell.pos; - const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; - - io.write(&pos, sizeof(pos)); - io.write(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id) { - for (auto seq_id : cell.seq_id) { - io.write(&seq_id, sizeof(seq_id)); - } - } - } - } -} - -void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { - const uint32_t v_trans = this->v_trans ? 1 : 0; - const uint32_t n_layer = layers.size(); - - io.write(&v_trans, sizeof(v_trans)); - io.write(&n_layer, sizeof(n_layer)); - - std::vector tmp_buf; - - // Iterate and write all the keys first, each row is a cell - // Get whole range at a time - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Write key type - const int32_t k_type_i = (int32_t)layer.k->type; - io.write(&k_type_i, sizeof(k_type_i)); - - // Write row size of key - const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); - io.write(&k_size_row, sizeof(k_size_row)); - - // Read each range of cells of k_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * k_size_row; - io.write_tensor(layer.k, range.first * k_size_row, buf_size); - } - } - - if (!v_trans) { - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)layer.v->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write row size of value - const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); - io.write(&v_size_row, sizeof(v_size_row)); - - // Read each range of cells of v_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * v_size_row; - io.write_tensor(layer.v, range.first * v_size_row, buf_size); - } - } - } else { - // When v is transposed, we also need the element size and get the element ranges from each row - const uint32_t kv_size = size; - - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // 
Write value type - const int32_t v_type_i = (int32_t)layer.v->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write element size - const uint32_t v_size_el = ggml_type_size(layer.v->type); - io.write(&v_size_el, sizeof(v_size_el)); - - // Write GQA embedding size - io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); - - // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t src_offset = (range.first + j * kv_size) * v_size_el; - const size_t buf_size = range_size * v_size_el; - io.write_tensor(layer.v, src_offset, buf_size); - } - } - } - } -} - -bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { - if (dest_seq_id != -1) { - // single sequence - - seq_rm(dest_seq_id, -1, -1); - - llama_sbatch sbatch; - llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); - - batch.n_tokens = cell_count; - - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id != 0) { - LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); - return false; - } - - batch.pos[i] = pos; - batch.n_seq_id[i] = 1; - batch.seq_id[i] = &dest_seq_id; - } - - if (!find_slot(batch)) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); - return false; - } - - commit(); - - // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(head + cell_count <= size); - GGML_ASSERT(cells[head].pos == batch.pos[0]); - GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); - GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); - GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); - } else { - // whole KV cache restore - - if (cell_count > size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); - return false; - } - - clear(); - - for (uint32_t i = 0; i < cell_count; ++i) { - kv_cell & cell = cells[i]; - - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - cell.pos = pos; - - for (uint32_t j = 0; j < n_seq_id; ++j) { - llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); - - if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); - return false; - } - - cell.seq_id.insert(seq_id); - } - } - - head = 0; - used = cell_count; - } - - return true; -} - -bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { - uint32_t v_trans; - uint32_t n_layer; - - io.read_to(&v_trans, sizeof(v_trans)); - io.read_to(&n_layer, sizeof(n_layer)); - - if (n_layer != layers.size()) { - LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); - return false; - } - if (cell_count > size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); - return false; - } - if (this->v_trans != (bool) v_trans) { - LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); - return 
false; - } - - // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Read type of key - int32_t k_type_i_ref; - io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t) layer.k->type; - if (k_type_i != k_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); - return false; - } - - // Read row size of key - uint64_t k_size_row_ref; - io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); - if (k_size_row != k_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); - } - } - - if (!this->v_trans) { - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read row size of value - uint64_t v_size_row_ref; - io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); - if (v_size_row != v_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the values for the whole cell range - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); - } - } - } else { - // For each layer, read the values for each cell (transposed) - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read element size of value - uint32_t v_size_el_ref; - io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(layer.v->type); - if (v_size_el != v_size_el_ref) { - LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); - return false; - } - - // Read GQA embedding size - uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); - if (n_embd_v_gqa != n_embd_v_gqa_ref) { - LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); - return false; - } - - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < 
n_embd_v_gqa; ++j) { - const size_t dst_offset = (head + j * size) * v_size_el; - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); - } - } - } - } - - return true; -} - -// -// llama_kv_cache_unified_iswa -// - -llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - bool swa_full, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_batch, - uint32_t n_pad) : hparams(model.hparams) { - llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; - llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; - - const uint32_t size_base = kv_size; - - uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad)); - - // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning - if (swa_full) { - LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n", - __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); - - size_swa = size_base; - do_prune = false; - } - - LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); - - kv_base = std::make_unique( - model, std::move(filter_base), type_k, type_v, - v_trans, offload, size_base, n_seq_max, n_pad, - 0, LLAMA_SWA_TYPE_NONE); - - LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); - - kv_swa = std::make_unique( - model, std::move(filter_swa), type_k, type_v, - v_trans, offload, size_swa, n_seq_max, n_pad, - hparams.n_swa, hparams.swa_type); -} - -void llama_kv_cache_unified_iswa::clear() { - kv_base->clear(); - kv_swa ->clear(); -} - -bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - bool res = true; - - res = res & kv_base->seq_rm(seq_id, p0, p1); - res = res & kv_swa ->seq_rm(seq_id, p0, p1); - - return res; -} - -void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); - kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); -} - -void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { - kv_base->seq_keep(seq_id); - kv_swa ->seq_keep(seq_id); -} - -void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - kv_base->seq_add(seq_id, p0, p1, delta); - kv_swa ->seq_add(seq_id, p0, p1, delta); -} - -void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - kv_base->seq_div(seq_id, p0, p1, d); - kv_swa ->seq_div(seq_id, p0, p1, d); -} - -llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const { - // the base cache is a superset of the SWA cache, so we can just check the SWA cache - return kv_swa->seq_pos_min(seq_id); -} - -llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { - return kv_swa->seq_pos_max(seq_id); -} - -void llama_kv_cache_unified_iswa::restore() { - kv_base->restore(); - kv_swa ->restore(); -} - -void llama_kv_cache_unified_iswa::commit() { - kv_base->commit(); - kv_swa ->commit(); - - // slide the attention window, forgetting/pruning old tokens that are outside the window - if (do_prune) { - for (const auto & [seq_id, entry] : pending.pos) { - kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax); - } 
- - } - - pending.clear(); -} - -bool llama_kv_cache_unified_iswa::update(llama_context & lctx) { - bool res = true; - - res = res & kv_base->update(lctx); - res = res & kv_swa ->update(lctx); - - return res; -} - -void llama_kv_cache_unified_iswa::defrag_sched(float thold) { - kv_base->defrag_sched(thold); - kv_swa ->defrag_sched(thold); -} - -void llama_kv_cache_unified_iswa::set_full() { - kv_base->set_full(); - kv_swa ->set_full(); -} - -llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) { - pending.clear(); - - if (do_prune) { - for (int i = 0; i < batch.n_tokens; ++i) { - for (int s = 0; s < batch.n_seq_id[i]; ++s) { - const llama_seq_id seq_id = batch.seq_id[i][s]; - const llama_pos pos = batch.pos[i]; - - if (pending.pos.find(seq_id) == pending.pos.end()) { - pending.pos[seq_id].pmin = pos; - pending.pos[seq_id].pmax = pos; - } else { - pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos); - pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos); - } - } - } - } - - return llama_sbatch(batch, hparams.n_embd, true, logits_all); -} - -llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - GGML_UNUSED(embd_pooled); - return sbatch.split_simple(n_ubatch); -} - -bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) { - bool res = true; - - res = res & kv_base->find_slot(batch); - res = res & kv_swa ->find_slot(batch); - - return res; -} - -bool llama_kv_cache_unified_iswa::get_can_shift() const { - return kv_base->get_size() == kv_swa->get_size(); -} - -void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - kv_base->state_write(io, seq_id); - kv_swa ->state_write(io, seq_id); -} - -void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - kv_base->state_read(io, seq_id); - kv_swa ->state_read(io, seq_id); -} - -llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const { - return kv_base.get(); -} - -llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const { - return kv_swa.get(); -} - -// -// llama_kv_cache_recurrent -// - -llama_kv_cache_recurrent::llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { - const int32_t n_layer = hparams.n_layer; - - LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", - __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); - - head = 0; - size = kv_size; - used = 0; - - cells.clear(); - cells.resize(kv_size); - - // create a context for each buffer type - std::map ctx_map; - auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { - auto it = ctx_map.find(buft); - if (it == ctx_map.end()) { - ggml_init_params params = { - /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context * ctx = ggml_init(params); - if (!ctx) { - return nullptr; - } - - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); - - return ctx; - } - - return it->second; - }; - - k_l.reserve(n_layer); - v_l.reserve(n_layer); - - for (int i = 0; i < n_layer; i++) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + 
hparams.n_embd_v_s(); - - const char * dev_name = "CPU"; - - ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); - - if (offload) { - auto * dev = model.dev_layer(i); - buft = ggml_backend_dev_buffer_type(dev); - - dev_name = ggml_backend_dev_name(dev); - } - - LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name); - - ggml_context * ctx = ctx_for_buft(buft); - if (!ctx) { - throw std::runtime_error("failed to create ggml context for kv cache"); - } - - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); - k_l.push_back(k); - v_l.push_back(v); - } - - // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); - if (!buf) { - throw std::runtime_error("failed to allocate buffer for kv cache"); - } - ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); - bufs.emplace_back(buf); - } - - { - const size_t memory_size_k = size_k_bytes(); - const size_t memory_size_v = size_v_bytes(); - - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); - } -} - -void llama_kv_cache_recurrent::clear() { - for (int32_t i = 0; i < (int32_t) size; ++i) { - cells[i].pos = -1; - cells[i].seq_id.clear(); - cells[i].src = -1; - cells[i].tail = -1; - } - head = 0; - used = 0; - - for (auto & buf : bufs) { - ggml_backend_buffer_clear(buf.get(), 0); - } -} - -bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - uint32_t new_head = size; - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // models like Mamba or RWKV can't have a state partially erased - if (seq_id >= (int64_t) size) { - // could be fatal - return false; - } - if (0 <= seq_id) { - int32_t & tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - const kv_cell & cell = cells[tail_id]; - // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { - return false; - } - // invalidate tails which will be cleared - if (p0 <= cell.pos && cell.pos < p1) { - tail_id = -1; - } - } - } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; - } - } - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].pos >= p0 && cells[i].pos < p1) { - if (seq_id < 0) { - cells[i].seq_id.clear(); - } else if (cells[i].has_seq_id(seq_id)) { - cells[i].seq_id.erase(seq_id); - } else { - continue; - } - if (cells[i].is_empty()) { - // keep count of the number of used cells - if (cells[i].pos >= 0) { - used--; - } - cells[i].pos = -1; - cells[i].src = -1; - if (new_head == size) { - new_head = i; - } - } - } - } - - // If we freed up a slot, set head to it so searching can start there. 
- if (new_head != size && new_head < head) { - head = new_head; - } - - return true; -} - -void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (seq_id_src == seq_id_dst) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { - kv_cell & tail_src = cells[seq_id_src]; - kv_cell & tail_dst = cells[seq_id_dst]; - if (tail_dst.tail >= 0) { - // clear destination seq_id if it wasn't empty - kv_cell & cell_dst = cells[tail_dst.tail]; - - cell_dst.seq_id.erase(seq_id_dst); - tail_dst.tail = -1; - if (cell_dst.seq_id.empty()) { - cell_dst.pos = -1; - cell_dst.src = -1; - used -= 1; - } - } - if (tail_src.tail >= 0) { - kv_cell & cell_src = cells[tail_src.tail]; - - cell_src.seq_id.insert(seq_id_dst); - tail_dst.tail = tail_src.tail; - } - } -} - -void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { - uint32_t new_head = size; - - for (uint32_t i = 0; i < size; ++i) { - if ((llama_seq_id) i != seq_id) { - cells[i].tail = -1; - } - - if (!cells[i].has_seq_id(seq_id)) { - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; - cells[i].src = -1; - cells[i].seq_id.clear(); - - if (new_head == size){ - new_head = i; - } - } else { - cells[i].seq_id.clear(); - cells[i].seq_id.insert(seq_id); - } - } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != size && new_head < head) { - head = new_head; - } -} - -void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - if (delta == 0) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over the - if (p0 == p1) { - return; - } - - // for Mamba-like or RWKV models, only the pos needs to be shifted - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - kv_cell & cell = cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; - } - } - } -} - -void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - if (d == 1) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over the cache. 
- if (p0 == p1) { - return; - } - - // for Mamba-like or RWKV models, only the pos needs to be changed - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - kv_cell & cell = cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; - } - } - } -} - -llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const { - llama_pos result = std::numeric_limits::max(); - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id)) { - result = std::min(result, cells[i].pos); - } - } - - if (result == std::numeric_limits::max()) { - result = -1; - } - - return result; -} - -llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { - llama_pos result = -1; - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id)) { - result = std::max(result, cells[i].pos); - } - } - - return result; -} - -void llama_kv_cache_recurrent::restore() { - if (pending.ranges.empty()) { - return; - } - - seq_rm(-1, -1, -1); -} - -void llama_kv_cache_recurrent::commit() { - pending.ranges.clear(); -} - -bool llama_kv_cache_recurrent::update(llama_context & ctx) { - GGML_UNUSED(ctx); - return false; -} - -void llama_kv_cache_recurrent::defrag_sched(float thold) { - GGML_UNUSED(thold); - // noop -} - -void llama_kv_cache_recurrent::set_full() { - n = size; - head = 0; -} - -llama_sbatch llama_kv_cache_recurrent::sbatch_init( - const llama_batch & batch, - bool logits_all) { - return llama_sbatch(batch, hparams.n_embd, false, logits_all); -} - -llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - if (embd_pooled) { - // Pooled embeddings cannot be split across ubatches (yet) - return sbatch.split_seq(n_ubatch); - } - - return sbatch.split_equal(n_ubatch); -} - -bool llama_kv_cache_recurrent::find_slot( - const llama_ubatch & ubatch) { - const uint32_t n_tokens = ubatch.n_tokens; - const uint32_t n_seqs = ubatch.n_seqs; - - const uint32_t n_seq_tokens = ubatch.n_seq_tokens; - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (head > used + 2*n_tokens) { - head = 0; - } - - // For recurrent state architectures (like Mamba or RWKV), - // each cache cell can store the state for a whole sequence. - // A slot should be always be contiguous. - - // can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(ubatch.equal_seqs); - - int32_t min = size - 1; - int32_t max = 0; - - // everything should fit if all seq_ids are smaller than the max - for (uint32_t s = 0; s < n_seqs; ++s) { - const uint32_t n_seq_id = ubatch.n_seq_id[s]; - for (uint32_t j = 0; j < n_seq_id; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; - - if (seq_id < 0 || (uint32_t) seq_id >= size) { - // too big seq_id - // TODO: would it be possible to resize the cache instead? 
- LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max); - return false; - } - if (j > 0) { - kv_cell & seq = cells[seq_id]; - if (seq.tail >= 0) { - kv_cell & cell = cells[seq.tail]; - // clear cells from seq_ids that become shared - // (should not normally happen, but let's handle it anyway) - cell.seq_id.erase(seq_id); - seq.tail = -1; - if (cell.seq_id.empty()) { - cell.pos = -1; - cell.src = -1; - used -= 1; - } - } - } - } - } - -#ifndef NDEBUG - { - std::vector tails_verif; - tails_verif.assign(size, -1); - for (uint32_t i = 0; i < size; ++i) { - kv_cell & cell = cells[i]; - for (llama_seq_id seq_id : cell.seq_id) { - if (tails_verif[seq_id] != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); - } - tails_verif[seq_id] = i; - } - } - for (uint32_t i = 0; i < size; ++i) { - if (tails_verif[i] != cells[i].tail) { - LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); - } - } - } -#endif - - // find next empty cell - uint32_t next_empty_cell = head; - - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - - // find usable cell range - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - kv_cell & seq_meta = cells[seq_id]; - bool has_cell = false; - if (seq_meta.tail >= 0) { - kv_cell & cell = cells[seq_meta.tail]; - GGML_ASSERT(cell.has_seq_id(seq_id)); - // does this seq_id "own" the cell? - if (cell.seq_id.size() == 1) { has_cell = true; } - } - if (!has_cell) { - kv_cell & empty_cell = cells[next_empty_cell]; - GGML_ASSERT(empty_cell.is_empty()); - // copy old tail into the empty cell - if (seq_meta.tail >= 0) { - kv_cell & orig_cell = cells[seq_meta.tail]; - empty_cell.pos = orig_cell.pos; - empty_cell.src = orig_cell.src; - orig_cell.seq_id.erase(seq_id); - empty_cell.seq_id.insert(seq_id); // will be overwritten - } - seq_meta.tail = next_empty_cell; - // find next empty cell - if (s + 1 < n_seqs) { - next_empty_cell += 1; - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - } - } - if (min > seq_meta.tail) { min = seq_meta.tail; } - if (max < seq_meta.tail) { max = seq_meta.tail; } - } - - // gather and re-order - for (uint32_t s = 0; s < n_seqs; ++s) { - int32_t dst_id = s + min; - int32_t src_id = cells[ubatch.seq_id[s][0]].tail; - if (dst_id != src_id) { - kv_cell & dst_cell = cells[dst_id]; - kv_cell & src_cell = cells[src_id]; - - std::swap(dst_cell.pos, src_cell.pos); - std::swap(dst_cell.src, src_cell.src); - std::swap(dst_cell.seq_id, src_cell.seq_id); - - // swap tails (assuming they NEVER overlap) - for (const llama_seq_id seq_id : src_cell.seq_id) { - cells[seq_id].tail = src_id; - } - for (const llama_seq_id seq_id : dst_cell.seq_id) { - cells[seq_id].tail = dst_id; - } - } - } - - // update the pos of the used seqs - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; - int32_t cell_id = s + min; - kv_cell & cell = cells[cell_id]; - - if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { - // What should happen when the pos backtracks or skips a 
value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", - __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); - } - cell.pos = last_pos; - cell.seq_id.clear(); - for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; - cell.seq_id.insert(seq_id); - cells[seq_id].tail = cell_id; - } - } - - // allow getting the range of used cells, from head to head + n - head = min; - n = max - min + 1; - used = std::count_if(cells.begin(), cells.end(), - [](const kv_cell & cell){ return !cell.is_empty(); }); - - // sanity check - return n >= n_seqs; -} - -bool llama_kv_cache_recurrent::get_can_shift() const { - return false; -} - -int32_t llama_kv_cache_recurrent::s_copy(int i) const { - const uint32_t cell_id = i + head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - kv_cell & cell = const_cast(cells[cell_id]); - - // prevent out-of-bound sources - if (cell.src < 0 || (uint32_t) cell.src >= size) { - cell.src = cell_id; - } - - int32_t res = cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (cell.src != (int32_t) cell_id) { - cell.src = cell_id; - } - - return res; -} - -float llama_kv_cache_recurrent::s_mask(int i) const { - const uint32_t cell_id = i + head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - kv_cell & cell = const_cast(cells[cell_id]); - - float res = (float) (cell.src >= 0); - - // only clear once - if (cell.src < 0) { - cell.src = cell_id; - } - - return res; -} - -uint32_t llama_kv_cache_recurrent::cell_max() const { - for (uint32_t i = size; i > 0; --i) { - const kv_cell & cell = cells[i - 1]; - - if (cell.pos >= 0 && !cell.is_empty()) { - return i; - } - } - - return 0; -} - -size_t llama_kv_cache_recurrent::total_size() const { - size_t size = 0; - for (const auto & buf : bufs) { - size += ggml_backend_buffer_get_size(buf.get()); - } - - return size; -} - -size_t llama_kv_cache_recurrent::size_k_bytes() const { - size_t size_k_bytes = 0; - - for (const auto & k : k_l) { - size_k_bytes += ggml_nbytes(k); - } - - return size_k_bytes; -} - -size_t llama_kv_cache_recurrent::size_v_bytes() const { - size_t size_v_bytes = 0; - - for (const auto & v : v_l) { - size_v_bytes += ggml_nbytes(v); - } - - return size_v_bytes; -} - -void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - uint32_t cell_count = 0; - - // Count the number of cells with the specified seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = size; - for (uint32_t i = 0; i < size; ++i) { - const auto & cell = cells[i]; - if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { - ++cell_count; - if (cell_range_begin == size) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != size) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = size; - } - } - } - if (cell_range_begin != size) { - cell_ranges.emplace_back(cell_range_begin, size); - } - - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; - } - GGML_ASSERT(cell_count == 
cell_count_check); - - io.write(&cell_count, sizeof(cell_count)); - - state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); -} - -void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); - - bool res = true; - - res = res && state_read_meta(io, cell_count, seq_id); - res = res && state_read_data(io, cell_count); - - if (!res) { - if (seq_id == -1) { - clear(); - } else { - seq_rm(seq_id, -1, -1); - } - throw std::runtime_error("failed to restore kv cache"); - } -} - -void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { - for (const auto & range : cell_ranges) { - for (uint32_t i = range.first; i < range.second; ++i) { - const auto & cell = cells[i]; - const llama_pos pos = cell.pos; - const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; - - io.write(&pos, sizeof(pos)); - io.write(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id) { - for (auto seq_id : cell.seq_id) { - io.write(&seq_id, sizeof(seq_id)); - } - } - } - } -} - -void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { - const uint32_t v_trans = 0; - const uint32_t n_layer = hparams.n_layer; - - io.write(&v_trans, sizeof(v_trans)); - io.write(&n_layer, sizeof(n_layer)); - - std::vector tmp_buf; - - // Iterate and write all the keys first, each row is a cell - // Get whole range at a time - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Write key type - const int32_t k_type_i = (int32_t)k_l[il]->type; - io.write(&k_type_i, sizeof(k_type_i)); - - // Write row size of key - const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - io.write(&k_size_row, sizeof(k_size_row)); - - // Read each range of cells of k_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * k_size_row; - io.write_tensor(k_l[il], range.first * k_size_row, buf_size); - } - } - - if (!v_trans) { - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)v_l[il]->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write row size of value - const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); - io.write(&v_size_row, sizeof(v_size_row)); - - // Read each range of cells of v_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * v_size_row; - io.write_tensor(v_l[il], range.first * v_size_row, buf_size); - } - } - } else { - // When v is transposed, we also need the element size and get the element ranges from each row - const uint32_t kv_size = size; - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)v_l[il]->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write element size - const uint32_t v_size_el = ggml_type_size(v_l[il]->type); - io.write(&v_size_el, sizeof(v_size_el)); - - // Write GQA embedding size - io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); - - // 
For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t src_offset = (range.first + j * kv_size) * v_size_el; - const size_t buf_size = range_size * v_size_el; - io.write_tensor(v_l[il], src_offset, buf_size); - } - } - } - } -} - -bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { - if (dest_seq_id != -1) { - // single sequence - - seq_rm(dest_seq_id, -1, -1); - - llama_sbatch sbatch; - llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); - - batch.n_tokens = cell_count; - batch.n_seq_tokens = cell_count; - batch.n_seqs = 1; - - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id != 0) { - LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); - return false; - } - - batch.pos[i] = pos; - } - batch.n_seq_id[0] = 1; - batch.seq_id[0] = &dest_seq_id; - if (!find_slot(batch)) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); - return false; - } - commit(); - - // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(head + cell_count <= size); - GGML_ASSERT(cells[head].pos == batch.pos[0]); - GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); - GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); - GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); - } else { - // whole KV cache restore - - if (cell_count > size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); - return false; - } - - clear(); - - for (uint32_t i = 0; i < cell_count; ++i) { - kv_cell & cell = cells[i]; - - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - cell.pos = pos; - - for (uint32_t j = 0; j < n_seq_id; ++j) { - llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); - - // TODO: llama_kv_cache_recurrent should have a notion of max sequences - //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - if (seq_id < 0) { - //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); - return false; - } - - cell.seq_id.insert(seq_id); - - int32_t & tail = cells[seq_id].tail; - if (tail != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); - return false; - } - tail = i; - } - } - - head = 0; - used = cell_count; - } - - for (uint32_t i = 0; i < cell_count; ++i) { - uint32_t cell_id = head + i; - // make sure the recurrent states will keep their restored state - cells[cell_id].src = cell_id; - } - - return true; -} - -bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { - uint32_t v_trans; - uint32_t n_layer; - io.read_to(&v_trans, sizeof(v_trans)); - io.read_to(&n_layer, sizeof(n_layer)); - - if (n_layer != hparams.n_layer) { - LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, 
hparams.n_layer); - return false; - } - if (cell_count > size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); - return false; - } - if (false != (bool) v_trans) { - LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); - return false; - } - - // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Read type of key - int32_t k_type_i_ref; - io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t) k_l[il]->type; - if (k_type_i != k_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); - return false; - } - - // Read row size of key - uint64_t k_size_row_ref; - io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - if (k_size_row != k_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); - } - } - - if (!v_trans) { - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)v_l[il]->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read row size of value - uint64_t v_size_row_ref; - io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); - if (v_size_row != v_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the values for the whole cell range - ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); - } - } - } else { - // For each layer, read the values for each cell (transposed) - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)v_l[il]->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read element size of value - uint32_t v_size_el_ref; - io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(v_l[il]->type); - if (v_size_el != v_size_el_ref) { - LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); - return false; - } - - // Read GQA embedding size - uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); - if (n_embd_v_gqa != n_embd_v_gqa_ref) { - LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer 
%d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); - return false; - } - - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (head + j * size) * v_size_el; - ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); - } - } - } - } - - return true; -} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 191a1090a..2d04705f2 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -2,58 +2,34 @@ #include "llama.h" #include "llama-io.h" -#include "llama-graph.h" #include "llama-memory.h" -#include "ggml-cpp.h" - -#include -#include -#include - -struct llama_cparams; -struct llama_hparams; -struct llama_ubatch; -struct llama_sbatch; -struct llama_model; -struct llama_context; - struct llama_kv_cache : public llama_memory_i { virtual ~llama_kv_cache() = default; - // call if batch processing fails - restores the cache state - virtual void restore() = 0; + // split the input batch into a set of ubatches and verify that they can fit into the cache + // return a state object containing the ubatches and KV cache state required to process them + // check the llama_memory_state_i::get_status() for the result + virtual llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) = 0; - // call after successful batch processing - clears any pending state - virtual void commit() = 0; + // simulate full cache, used for allocating worst-case compute buffers + virtual llama_memory_state_ptr init_full() = 0; // process any pending defrag/shift/etc. operations // optionally call once before processing a new batch + // return true if any operations were performed virtual bool update(llama_context & lctx) = 0; // schedule a defrag if the fragmentation threshold is exceeded. 
otherwise, do nothing + // TODO: change to + // llama_memory_state_ptr init_defrag(float thold) = 0; + // virtual void defrag_sched(float thold) = 0; - // simulate full cache, used for allocating worst-case compute buffers - virtual void set_full() = 0; - - // - // batch processing - // - - // ============================================================================================================= - // TODO: refactor and simplify this - - virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; - - // different KV caches require different batch splitting strategies - virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0; - - // find an empty slot of size "n_tokens" in the cache - virtual bool find_slot(const llama_ubatch & batch) = 0; - - // ============================================================================================================= - // getters virtual bool get_can_shift() const = 0; @@ -66,450 +42,3 @@ struct llama_kv_cache : public llama_memory_i { virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; }; - -// -// llama_kv_cache_guard -// - -struct llama_kv_cache_guard { - llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {} - - ~llama_kv_cache_guard() { - kv->restore(); - } - - void commit() { - kv->commit(); - } - -private: - llama_kv_cache * kv; -}; - -// -// llama_kv_cache_unified -// - -class llama_kv_cache_unified : public llama_kv_cache { -public: - static uint32_t get_padding(const llama_cparams & cparams); - - // this callback is used to filter out layers that should not be included in the cache - using layer_filter_cb = std::function; - - llama_kv_cache_unified( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type); - - ~llama_kv_cache_unified() = default; - - // - // llama_memory_i - // - - void clear() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // - // llama_kv_cache - // - - void restore() override; - void commit() override; - - bool update(llama_context & ctx) override; - - void defrag_sched(float thold) override; - - void set_full() override; - - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; - - // updates the cache head - // Note: On success, it's important that cache.head points - // to the first cell of the slot. 
- bool find_slot(const llama_ubatch & batch) override; - - bool get_can_shift() const override; - - // state write/load - - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - - // - // llama_kv_cache_unified specific API - // - - uint32_t get_n() const; - uint32_t get_size() const; - - // get views of the current state of the cache - ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; - ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; - - // store k_cur and v_cur in the cache based on the current head location - ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; - ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; - - void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax); - - void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; - void set_input_k_shift (ggml_tensor * dst) const; - void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; - -private: - const llama_model & model; - const llama_hparams & hparams; - - struct kv_cell { - llama_pos pos = -1; - llama_pos delta = 0; - - // TODO: replace with bitset uint64_t - std::set seq_id; - - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } - - bool is_empty() const { - return seq_id.empty(); - } - - bool is_same_seq(const kv_cell & other) const { - return seq_id == other.seq_id; - } - }; - - struct kv_layer { - // layer index in the model - // note: can be different from the layer index in the KV cache - uint32_t il; - - ggml_tensor * k; - ggml_tensor * v; - }; - - bool has_shift = false; - bool do_defrag = false; - bool v_trans = true; // the value tensor is transposed - - uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) - uint32_t size = 0; // total number of cells, shared across all sequences - uint32_t used = 0; // used cells (i.e. 
at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt) - - // computed before each graph build - uint32_t n = 0; - - const uint32_t n_seq_max = 1; - - // required padding - const uint32_t n_pad = 1; - - // SWA - const uint32_t n_swa = 0; - - const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; - - std::vector ctxs; - std::vector bufs; - - std::vector cells; // TODO: replace with `struct kv_cells` - std::vector layers; - - // model layer id -> KV cache layer id - std::unordered_map map_layer_ids; - - // recovery information used to restore the KV cells to their original state in case of a failure - struct { - void clear() { - cells.clear(); - } - - std::unordered_map cells; - } recovery; - - // defrag - struct { - std::vector ids; - } defrag_info; - - // return true if cells have been moved - bool defrag_prepare(int32_t n_max_nodes); - - // find how many cells are currently in use - uint32_t cell_max() const; - - size_t total_size() const; - - size_t size_k_bytes() const; - size_t size_v_bytes() const; - - bool is_masked_swa(llama_pos p0, llama_pos p1) const; - - ggml_tensor * build_rope_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const; - - llm_graph_result_ptr build_graph_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const; - - llm_graph_result_ptr build_graph_defrag( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const; - - void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; - void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; - - bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t cell_count); -}; - -// -// llama_kv_cache_unified_iswa -// - -// utilizes two instances of llama_kv_cache_unified -// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers -// upon successful commit, the SWA cache removes old tokens outside the n_swa window - -class llama_kv_cache_unified_iswa : public llama_kv_cache { -public: - llama_kv_cache_unified_iswa( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - bool swa_full, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_batch, - uint32_t n_pad); - - ~llama_kv_cache_unified_iswa() = default; - - // - // llama_memory_i - // - - void clear() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // - // llama_kv_cache - // - - void restore() override; - void commit() override; - - bool update(llama_context & ctx) override; - - void defrag_sched(float thold) override; - - void set_full() override; - - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; - - 
bool find_slot(const llama_ubatch & batch) override; - - bool get_can_shift() const override; - - // state write/load - - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - - // - // llama_kv_cache_unified_iswa specific API - // - - llama_kv_cache_unified * get_kv_base() const; - llama_kv_cache_unified * get_kv_swa () const; - -private: - const llama_hparams & hparams; - - bool do_prune = true; - - struct { - struct entry { - llama_pos pmin; - llama_pos pmax; - }; - - void clear() { - pos.clear(); - } - - // used to perform SWA pruning of old tokens - std::unordered_map pos; - } pending; - - std::unique_ptr kv_base; - std::unique_ptr kv_swa; -}; - -// -// llama_kv_cache_recurrent -// - -class llama_kv_cache_recurrent : public llama_kv_cache { -public: - struct kv_cell { - llama_pos pos = -1; - int32_t src = -1; // used to copy states - int32_t tail = -1; - - std::set seq_id; - - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } - - bool is_empty() const { - return seq_id.empty(); - } - - bool is_same_seq(const kv_cell & other) const { - return seq_id == other.seq_id; - } - }; - - llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max); - - ~llama_kv_cache_recurrent() = default; - - // - // llama_memory_i - // - - void clear() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // - // llama_kv_cache - // - - void restore() override; - void commit() override; - - bool update(llama_context & ctx) override; - - void defrag_sched(float thold) override; - - void set_full() override; - - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; - - bool find_slot(const llama_ubatch & batch) override; - - bool get_can_shift() const override; - - // TODO: temporary methods - they are not really const as they do const_cast<>, fix this - int32_t s_copy(int i) const; - float s_mask(int i) const; - - // state write/load - - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - - uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) - uint32_t size = 0; // total number of cells, shared across all sequences - uint32_t used = 0; // used cells (i.e. 
at least one seq_id)
-
-    // computed before each graph build
-    uint32_t n = 0;
-
-    std::vector<kv_cell> cells;
-
-    std::vector<ggml_tensor *> k_l; // per layer
-    std::vector<ggml_tensor *> v_l;
-
-private:
-    //const llama_model & model;
-    const llama_hparams & hparams;
-
-    // commit/restore cache
-    // TODO: rework for recurrent cache
-    struct slot_range {
-        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
-        uint32_t c1 = 0;
-    };
-
-    // pending cell updates that are not yet committed
-    struct {
-        std::vector<slot_range> ranges;
-    } pending;
-
-    const uint32_t n_seq_max = 1;
-
-    std::vector<ggml_context_ptr>        ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
-
-    // find how many cells are currently in use
-    uint32_t cell_max() const;
-
-    size_t total_size() const;
-
-    size_t size_k_bytes() const;
-    size_t size_v_bytes() const;
-
-    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
-    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
-
-    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
-};
diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h
new file mode 100644
index 000000000..9e2c4d927
--- /dev/null
+++ b/src/llama-kv-cells.h
@@ -0,0 +1,410 @@
+#pragma once
+
+#include "llama.h"
+#include "llama-cparams.h"
+
+#include <bitset>
+#include <cassert>
+#include <set>
+#include <vector>
+
+// meta information about KV cells that can be part of multiple sequences at the same time
+// TODO: add unit tests
+class llama_kv_cells_unified {
+public:
+    void reset() {
+        for (uint32_t i = 0; i < pos.size(); ++i) {
+            pos[i]   = -1;
+            shift[i] =  0;
+            seq[i].reset();
+        }
+
+        has_shift = false;
+
+        used.clear();
+
+        for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            seq_pos[s].clear();
+        }
+    }
+
+    void reset_shift() {
+        has_shift = false;
+
+        for (uint32_t i = 0; i < shift.size(); ++i) {
+            shift[i] = 0;
+        }
+    }
+
+    uint32_t size() const {
+        return pos.size();
+    }
+
+    void resize(uint32_t n) {
+        pos.resize(n);
+        shift.resize(n);
+        seq.resize(n);
+
+        reset();
+    }
+
+    bool is_empty(uint32_t i) const {
+        assert(i < pos.size());
+        assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
+
+        return pos[i] == -1;
+    }
+
+    uint32_t get_used() const {
+        return used.size();
+    }
+
+    // the index of the first cell that is used
+    // return 0 if no cells are used
+    uint32_t used_min() const {
+        return used.empty() ? 0 : *used.begin();
+    }
+
+    // the index of the last cell that is used + 1
+    // return 0 if no cells are used
+    uint32_t used_max_p1() const {
+        return used.empty() ?
0 : *used.rbegin() + 1; + } + + bool get_has_shift() const { + return has_shift; + } + + // move cell isrc to idst (used during defrag) + void mv(uint32_t isrc, uint32_t idst) { + assert(isrc < pos.size()); + assert(idst < pos.size()); + + pos [idst] = pos [isrc]; + shift[idst] = shift[isrc]; + seq [idst] = seq [isrc]; + + pos [isrc] = -1; + shift[isrc] = 0; + seq [isrc].reset(); + + used.erase (isrc); + used.insert(idst); + } + + // copy the state of cells [i, i + n) (used for save/restore the state of the cells) + llama_kv_cells_unified cp(uint32_t i, uint32_t n) const { + assert(i + n <= pos.size()); + + llama_kv_cells_unified res; + + res.resize(n); + + for (uint32_t j = 0; j < n; ++j) { + res.pos[j] = pos[i + j]; + res.seq[j] = seq[i + j]; + + assert(shift[i + j] == 0); + } + + return res; + } + + // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells) + void set(uint32_t i, const llama_kv_cells_unified & other) { + assert(i + other.pos.size() <= pos.size()); + + for (uint32_t j = 0; j < other.pos.size(); ++j) { + if (pos[i + j] == -1 && other.pos[j] != -1) { + used.insert(i + j); + } + + if (pos[i + j] != -1 && other.pos[j] == -1) { + used.erase(i + j); + } + + if (pos[i + j] != -1) { + seq_pos_rm(i + j); + } + + pos[i + j] = other.pos[j]; + seq[i + j] = other.seq[j]; + + if (pos[i + j] != -1) { + seq_pos_add(i + j); + } + + assert(shift[i + j] == 0); + } + } + + // clear a non-empty cell + void rm(uint32_t i) { + assert(i < pos.size()); + assert(pos[i] != -1); + + seq_pos_rm(i); + + pos[i] = -1; + seq[i].reset(); + + used.erase(i); + } + + // note: call only if the cell has seq_id + // return true if the cell becomes empty + bool seq_rm(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + assert(seq[i].test(seq_id)); + assert(pos[i] != -1); + assert(seq_id >= 0); + + seq[i].reset(seq_id); + seq_pos[seq_id].erase(pos[i]); + + if (seq[i].none()) { + pos[i] = -1; + + used.erase(i); + + return true; + } + + return false; + } + + // return true if the cell becomes empty (i.e. 
it did not contain seq_id before the call) + bool seq_keep(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + + if (seq[i].test(seq_id)) { + seq_pos_rm(i); + seq[i].reset(); + + seq[i].set(seq_id); + seq_pos[seq_id].insert(pos[i]); + + return false; + } + + if (seq[i].any()) { + seq_pos_rm(i); + seq[i].reset(); + + pos[i] = -1; + + used.erase(i); + + return true; + } + + assert(pos[i] == -1); + + return false; + } + + // number of different sequences in the cell + int seq_count(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return seq[i].count(); + } + + // check if the cell contains seq_id + bool seq_has(uint32_t i, llama_seq_id seq_id) const { + assert(i < pos.size()); + assert(seq_id >= 0); + + return seq[i].test(seq_id); + } + + // note: call only if the cell is not empty and the seq_id is not in the cell + void seq_add(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + assert(pos[i] != -1); + assert(!seq[i].test(seq_id)); + + seq[i].set(seq_id); + seq_pos[seq_id].insert(pos[i]); + } + + // return the sequence id of this cell + // note: call only for cells with exactly one sequence + llama_seq_id seq_get(uint32_t i) const { + assert(seq[i].count() == 1); + + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (seq[i].test(s)) { + return s; + } + } + + return -1; + } + + // the minimum position of sequence seq_id currently present in any of the cells + // return -1 if the sequence is not present + llama_pos seq_pos_min(llama_seq_id seq_id) const { + assert(seq_id >= 0); + assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES); + + if (seq_pos[seq_id].empty()) { + return -1; + } + + return *seq_pos[seq_id].begin(); + } + + // the maximum position of sequence seq_id currently present in any of the cells + // return -1 if the sequence is not present + llama_pos seq_pos_max(llama_seq_id seq_id) const { + assert(seq_id >= 0); + assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES); + + if (seq_pos[seq_id].empty()) { + return -1; + } + + return *seq_pos[seq_id].rbegin(); + } + + // note: call only if the cell is not empty + llama_pos pos_get(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return pos[i]; + } + + // note: call only if the cell is not empty + llama_pos get_shift(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return shift[i]; + } + + // check if a cell is not empty and its position is within [p0, p1) + bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const { + assert(i < pos.size()); + + return pos[i] >= p0 && pos[i] < p1; + } + + // set the position of an empty cell + // does not modify "has_shift" + // note: call only if the cell is empty + void pos_set(uint32_t i, llama_pos p) { + assert(i < pos.size()); + assert(pos[i] == -1); + assert(seq[i].none()); + + pos[i] = p; + + used.insert(i); + } + + // pos[i] = pos[i] + d + // sets "has_shift" to true + // note: call only if the cell is not empty + bool pos_add(uint32_t i, llama_pos d) { + assert(i < pos.size()); + assert(pos[i] != -1); + + seq_pos_rm(i); + + pos[i] += d; + shift[i] += d; + + seq_pos_add(i); + + has_shift = true; + + if (pos[i] < 0) { + seq_pos_rm(i); + + seq[i].reset(); + pos[i] = -1; + + used.erase(i); + + return true; + } + + return false; + } + + // pos[i] = pos[i] / d + // sets "has_shift" to true + // note: call only if the cell is not empty + void pos_div(uint32_t i, int d) { + assert(i < pos.size()); + assert(pos[i] != -1); + + const llama_pos p_old = pos[i]; + + seq_pos_rm(i); + + pos[i] /= d; + shift[i] += 
p_old - pos[i];
+
+        seq_pos_add(i);
+
+        has_shift = true;
+    }
+
+private:
+    bool has_shift = false;
+
+    // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
+    std::set<uint32_t> used;
+
+    std::vector<llama_pos> pos;
+
+    // this array accumulates any applied shifts to the pos array since the last reset_shift() call
+    // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
+    //
+    //   cells.pos_add(x, shift_x);
+    //   cells.pos_div(y, shift_y);
+    //   ...
+    //
+    //   if (cells.has_shift()) {
+    //       for (int i = 0; i < n; ++i) {
+    //           auto shift_i = cells.get_shift(i);
+    //           ...
+    //       }
+    //       cells.reset_shift();
+    //   }
+    //
+    std::vector<llama_pos> shift;
+
+    using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
+
+    // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
+    std::vector<bits_t> seq;
+
+    // the set seq_pos[s] tells us which positions are currently present for sequence s
+    // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
+    std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
+
+    // helper functions for updating `seq_pos`, one cell at a time:
+
+    // remove cell i
+    void seq_pos_rm(uint32_t i) {
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (seq[i].test(s)) {
+                seq_pos[s].erase(pos[i]);
+            }
+        }
+    }
+
+    // add cell i
+    void seq_pos_add(uint32_t i) {
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (seq[i].test(s)) {
+                seq_pos[s].insert(pos[i]);
+            }
+        }
+    }
+};
diff --git a/src/llama-memory.h b/src/llama-memory.h
index c2571edc7..b3799d66e 100644
--- a/src/llama-memory.h
+++ b/src/llama-memory.h
@@ -2,6 +2,11 @@
 
 #include "llama.h"
 
+#include <memory>
+#include <vector>
+
+struct llama_ubatch;
+
 struct llama_memory_params {
     // kv cache
     ggml_type type_k;
@@ -22,7 +27,7 @@ public:
     virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0;
     virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
     virtual void seq_keep(llama_seq_id seq_id) = 0;
-    virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
+    virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
     virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
 
     virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
@@ -30,3 +35,42 @@ public:
 
     virtual bool get_can_edit() const = 0;
 };
+
+enum llama_memory_status {
+    LLAMA_MEMORY_STATUS_SUCCESS = 0,
+    LLAMA_MEMORY_STATUS_FAILED_PREPARE,
+    LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
+};
+
+// the interface for managing the memory state during batch processing
+// this interface is implemented per memory type. see:
+//   - llama_kv_cache_unified_state
+//   - llama_kv_cache_unified_iswa_state
+//   ...
+//
+// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
+//
+// TODO: rename to llama_memory_context_i ?
+class llama_memory_state_i {
+public:
+    virtual ~llama_memory_state_i() = default;
+
+    // consume the current ubatch from the state and proceed to the next one
+    // return false if we are done
+    virtual bool next() = 0;
+
+    // apply the memory state for the current ubatch to the memory object
+    // return false on failure
+    virtual bool apply() = 0;
+
+    // TODO: this might get reworked in the future when refactoring llama_batch
+    virtual std::vector<int64_t> & out_ids() = 0;
+
+    // get the current ubatch
+    virtual const llama_ubatch & get_ubatch() const = 0;
+
+    // get the status of the memory state
+    virtual llama_memory_status get_status() const = 0;
+};
+
+using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index b0512bf44..e46403f3c 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -5,7 +5,10 @@
 #include "llama-batch.h"
 #include "llama-cparams.h"
 #include "llama-model-loader.h"
-#include "llama-kv-cache.h"
+
+#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache-recurrent.h"
 
 #include "ggml-cpp.h"
 
@@ -464,11 +467,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         GGML_ASSERT(hparams.n_expert_used == 0);
     }
 
-    // zero-out the array hparams
     std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
     std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
+    std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
+
+    std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0);
+
     ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer, false);
     ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
 
@@ -575,7 +581,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
                 hparams.n_swa = 8192; // should this be a gguf kv?
currently it's the same for Scout and Maverick - hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full + hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full switch (hparams.n_expert) { case 16: type = LLM_TYPE_17B_16E; break; @@ -681,6 +687,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false); switch (hparams.n_layer) { case 3: @@ -876,7 +883,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; hparams.n_swa = 0; - hparams.n_swa_pattern = 1; + hparams.set_swa_pattern(1); } } break; case LLM_ARCH_PHIMOE: @@ -948,7 +955,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; // default value of gemma 2 - hparams.n_swa_pattern = 2; + hparams.set_swa_pattern(2); hparams.attn_soft_cap = true; ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -966,7 +973,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_GEMMA3: { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa_pattern = 6; + hparams.set_swa_pattern(6); hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; @@ -1051,7 +1058,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_COHERE2: { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa_pattern = 4; + hparams.set_swa_pattern(4); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); @@ -2123,7 +2130,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_NOMIC_BERT_MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); if (arch == LLM_ARCH_BERT) { pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); @@ -2131,8 +2138,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, TENSOR_NOT_REQUIRED); + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); } tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); @@ -2141,7 +2148,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - if (arch == LLM_ARCH_BERT) { + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + if (!layer.wqkv) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); 
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); @@ -2150,12 +2160,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } else { - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - } - - if (arch == LLM_ARCH_NOMIC_BERT_MOE) { - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); } layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); @@ -2529,7 +2533,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4363,7 +4371,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); - LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern); + LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); @@ -5923,8 +5931,10 @@ struct llm_build_bert : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); // token types are hardcoded to zero ("Sentence A") - ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); - inpL = ggml_add(ctx0, inpL, type_row0); + if (model.type_embd) { + ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); + inpL = ggml_add(ctx0, inpL, type_row0); + } if (model.arch == LLM_ARCH_BERT) { inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); } @@ -5945,36 +5955,11 @@ struct llm_build_bert : public llm_graph_context { ggml_tensor * Vcur; // self-attention - if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); - - if (model.layers[il].attn_q_norm) { - Qcur = build_norm(Qcur, - model.layers[il].attn_q_norm, - model.layers[il].attn_q_norm_b, - LLM_NORM, il); - } - - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); - - if (model.layers[il].attn_k_norm) { - Kcur = build_norm(Kcur, - model.layers[il].attn_k_norm, - model.layers[il].attn_k_norm_b, - LLM_NORM, il); - } - - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } else { - // compute Q 
and K and RoPE them + if (model.layers[il].wqkv) { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); - if (model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) { + if (model.layers[il].bqkv) { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); } @@ -5982,11 +5967,32 @@ struct llm_build_bert : public llm_graph_context { Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } else { + Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); + Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); + Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); + } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // RoPE + if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -8932,9 +8938,9 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory); + const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate); - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; @@ -8952,8 +8958,8 @@ struct llm_build_mamba : public llm_graph_context { GGML_ASSERT(ubatch.equal_seqs); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - ggml_tensor * conv_states_all = kv_self->k_l[il]; - ggml_tensor * ssm_states_all = kv_self->v_l[il]; + ggml_tensor * conv_states_all = kv_state->get_k_l(il); + ggml_tensor * ssm_states_all = kv_state->get_v_l(il); // (ab)using the KV cache to store the states ggml_tensor * conv = build_copy_mask_state( @@ -11680,7 +11686,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory); + const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -11690,7 +11696,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { const auto n_head = n_embd / head_size; const auto n_head_kv = hparams.n_head_kv(il); - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const auto & layer = model.layers[il]; @@ -11802,7 +11808,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { } ggml_tensor *
wkv_state = build_copy_mask_state( - gf, kv_self->v_l[il], state_copy, state_mask, + gf, kv_state->get_v_l(il), state_copy, state_mask, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output; @@ -11821,9 +11827,9 @@ struct llm_build_rwkv6_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_self->v_l[il], + kv_state->get_v_l(il), hparams.n_embd_v_s() * n_seqs, - hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il]) + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il)) ) ) ); @@ -12076,7 +12082,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory); + const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -12085,7 +12091,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { const auto head_count = n_embd / head_size; const auto n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const auto & layer = model.layers[il]; @@ -12156,7 +12162,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); ggml_tensor * wkv_state = build_copy_mask_state( - gf, kv_self->v_l[il], state_copy, state_mask, + gf, kv_state->get_v_l(il), state_copy, state_mask, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -12170,9 +12176,9 @@ struct llm_build_rwkv7_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_self->v_l[il], + kv_state->get_v_l(il), hparams.n_embd_v_s() * n_seqs, - hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il]) + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il)) ) ) ); @@ -13233,6 +13239,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_JINA_BERT_V3: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_WAVTOKENIZER_DEC: { res = nullptr; } break; @@ -13259,7 +13266,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - GGML_ASSERT(hparams.n_swa_pattern != 1); + GGML_ASSERT(hparams.is_swa_any()); res = new llama_kv_cache_unified_iswa( *this, @@ -13270,10 +13277,10 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, params.swa_full, cparams.n_ctx, cparams.n_seq_max, - cparams.n_batch, + cparams.n_ubatch, padding); } else { - GGML_ASSERT(hparams.n_swa_pattern == 1); + GGML_ASSERT(!hparams.is_swa_any()); res = new llama_kv_cache_unified( *this, @@ -13302,7 +13309,6 @@ llm_graph_result_ptr llama_model::build_graph( switch (arch) { case LLM_ARCH_LLAMA: - case LLM_ARCH_MINICPM: { llm = std::make_unique<llm_build_llama>(*this, params, gf); } break; @@ -13544,6 +13550,7 @@ llm_graph_result_ptr llama_model::build_graph( } break; case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_MINICPM: { llm = std::make_unique<llm_build_granite>(*this, params, gf); } break; @@ -13634,6 +13641,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) { return model->hparams.n_head_kv(); } +int32_t llama_model_n_swa(const llama_model * model) { + return model->hparams.n_swa; +} + // deprecated int32_t llama_n_ctx_train(const
llama_model * model) { return llama_model_n_ctx_train(model); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 804b11e0a..bfbf5fa23 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -798,7 +798,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d } // if we have enough values the operation was a success - if (filtered_tokens.size() >= ctx->min_keep) { + if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) { memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); cur_p->size = filtered_tokens.size(); min_p_applied = true; @@ -909,7 +909,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token cum_sum += cur_p->data[idx].p; // Check if the running sum is greater than typical or if we have kept at least min_keep tokens - if (cum_sum > ctx->p && i >= ctx->min_keep - 1) { + if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) { last_idx = i + 1; break; } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 6d6a1c6e1..00d26ff85 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -835,7 +835,7 @@ struct llm_tokenizer_ugm_session { } // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores - std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX}); + std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX}); // at the beginning tokenization score is zero tokenization_results[0] = { vocab.token_unk(), 0, 0 }; @@ -867,7 +867,7 @@ struct llm_tokenizer_ugm_session { const double challenger_score = current_best.score_sum + token_score; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { - struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score }; + struct best_tokenization challenger = { token_id, input_offset, challenger_score }; current_champ = challenger; } } @@ -881,7 +881,7 @@ struct llm_tokenizer_ugm_session { prefix_offset = input_offset + n_utf8_code_units; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { - struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score }; + struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score }; current_champ = challenger; } } @@ -1007,7 +1007,7 @@ private: struct best_tokenization { llama_token token_id; size_t input_offset; - float score_sum; + double score_sum; }; struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { @@ -2096,7 +2096,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // set attributes by model/tokenizer/architecture name if (false || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"}) - || _contains_any(general_arch, {"jina-bert-v3"}) + || _contains_any(general_arch, {"nomic-bert-moe", "jina-bert-v3"}) ) { _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true); } else if (_contains_any(model_name, {"phi-3", "phi3"})) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 083347d18..83f7d1a45 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -97,6 +97,9 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2 ARGS ${CMAKE llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS
${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf) llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf) +# TODO: missing HF tokenizer for this model in convert_hf_to_gguf_update.py, see https://github.com/ggml-org/llama.cpp/pull/13847 +# llama_test(test-tokenizer-0 NAME test-tokenizer-0-nomic-bert-moe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-nomic-bert-moe.gguf) + if (LLAMA_LLGUIDANCE) llama_build_and_test(test-grammar-llguidance.cpp ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf) endif () @@ -142,8 +145,10 @@ if (NOT WIN32) # llama_build_and_test(test-double-float.cpp) # SLOW endif() -llama_build_and_test(test-log.cpp) +llama_build_and_test(test-chat-parser.cpp) llama_build_and_test(test-chat-template.cpp) +llama_build_and_test(test-json-partial.cpp) +llama_build_and_test(test-log.cpp) llama_build_and_test(test-regex-partial.cpp) # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135) diff --git a/tests/test-chat-parser.cpp b/tests/test-chat-parser.cpp new file mode 100644 index 000000000..59e44e07d --- /dev/null +++ b/tests/test-chat-parser.cpp @@ -0,0 +1,352 @@ +// Tests chat handling, including grammar generation and parsing for tool calling, for various templates. +// +// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates, +// e.g. given Minja (http://github.com/google/minja) checked out in parent dir: +// +// cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null +// +#include <iostream> +#include <regex> +#include <string> + +#include "chat-parser.h" +#include "common.h" +#include "log.h" +#include "regex-partial.h" + +template <class T> +static void assert_equals(const T & expected, const T & actual) { + if (expected != actual) { + std::cerr << "Expected: " << expected << std::endl; + std::cerr << "Actual: " << actual << std::endl; + std::cerr << std::flush; + throw std::runtime_error("Test failed"); + } +} +static void assert_equals(const char * expected, const std::string & actual) { + return assert_equals<std::string>(expected, actual); +} + +static void assert_throws(const std::function<void()> & fn, const std::string & expected_exception_pattern = "") { + try { + fn(); + } catch (const std::exception & e) { + if (expected_exception_pattern.empty()) { + return; + } + std::regex expected_exception_regex(expected_exception_pattern); + std::string actual_message = e.what(); + if (std::regex_search(actual_message, expected_exception_regex)) { + return; + } + throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")"); + throw std::runtime_error("Exception of unexpected type: " + std::string(e.what())); + } + throw std::runtime_error("Exception was expected but not thrown"); +} + +static void test_reasoning() { + { + common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, { + /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + }); + assert_equals(false, builder.try_parse_reasoning("", "")); + assert_equals("CogitoErgo sum", builder.consume_rest()); + } + { + common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, { + /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /*
.reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + }); + assert_equals(true, builder.try_parse_reasoning("", "")); + assert_equals(std::string("Cogito"), builder.result().reasoning_content); + assert_equals("Ergo sum", builder.consume_rest()); + } + { + common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, { + /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + }); + assert_equals(false, builder.try_parse_reasoning("", "")); + assert_equals("CogitoErgo sum", builder.consume_rest()); + } + { + common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, { + /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + }); + assert_equals(true, builder.try_parse_reasoning("", "")); + assert_equals(std::string("Cogito"), builder.result().reasoning_content); + assert_equals("Ergo sum", builder.consume_rest()); + } + { + common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, { + /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ true, + /* .thinking_forced_open = */ true, + }); + assert_equals(true, builder.try_parse_reasoning("", "")); + assert_equals("Cogito", builder.result().content); + assert_equals("Ergo sum", builder.consume_rest()); + } +} + +static void test_regex() { + auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") { + common_chat_msg_parser builder(input, /* is_partial= */ false, {}); + assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern); + }; + + test_throws("Hello, world!", "abc", "^abc$"); + test_throws("Hello, world!", "e", "^e$"); + + { + common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {}); + builder.consume_regex(common_regex("Hello")); + assert_equals(", world!", builder.consume_rest()); + } + + { + // When in non partial mode, we can say whether the regex was consumed or not. + common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {}); + assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value()); + } + { + common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {}); + auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?")); + assert_equals(true, res.has_value()); + // Verify captures + assert_equals(2, res->groups.size()); + assert_equals("Hell", builder.str(res->groups[0])); + assert_equals("el", builder.str(res->groups[1])); + // Verify position is after the match + assert_equals(4, builder.pos()); + assert_equals("o,", builder.consume_rest()); + } + { + // But in partial mode, we have a partial final match / can't decide, so we throw a partial exception. + common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {}); + assert_throws([&]() { + builder.try_consume_regex(common_regex("Hello, world!")); + }, "^Hello, world!$"); + } + + // Now regardless of the mode, we can tell these aren't a match. 
+ for (const auto is_partial : {false, true}) { + common_chat_msg_parser builder("Hello,", is_partial, {}); + assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value()); + } + for (const auto is_partial : {false, true}) { + common_chat_msg_parser builder("Hello,", is_partial, {}); + assert_equals(false, builder.try_consume_literal("Oh")); + } +} + +const std::vector<std::string> barely_healable_jsons = { + "{", + "{\"", + "{\"\\", + "{\"n", + "{\"name\"", + "{\"name\":", + "{\"name\":\"", + "{\"name\":\"\\", + "{\"name\":\"python", + "{\"name\":\"python\\", + "{\",", + "{\":", + "{\"[", + "{\"]", + "{\"{", + "{\"}", + "{\"1", + "{\"name\":\",", + "{\"name\":\":", + "{\"name\":\"[", + "{\"name\":\"]", + "{\"name\":\"{", + "{\"name\":\"}", + "{\"name\":\"1", +}; + +static void test(const std::string & input, bool is_partial, const std::vector<std::vector<std::string>> & args_paths, const std::vector<std::vector<std::string>> & content_paths, const std::string & expected) { + common_chat_msg_parser builder(input, is_partial, {}); + auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths); + assert_equals(true, js.has_value()); + assert_equals(is_partial, js->is_partial); + assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get<std::string>() : js->value.dump()); +} +static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) { + common_chat_msg_parser builder(input, parse_as_partial, {}); + auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {}); + assert_equals(true, js.has_value()); + assert_equals(is_partial, js->is_partial); + assert_equals(expected, js->value.dump()); +} + +static void test_json_with_dumped_args_no_args() { + // Normal JSON, nothing to heal, nothing to dump + test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}"); + // Full json is args + test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}"); + + // If the arguments are further down, don't heal partial content. + for (const auto & src : barely_healable_jsons) { + test(src, true, {{"arguments"}}, {}, "{}"); + } + // But heal content that isn't partial. + test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}"); +} + +static void test_json_with_dumped_args() { + + // Partial content. + test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}"); + test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}"); + test("{\"content\": ", true, {}, {{"content"}}, "{}"); + + // If the entire JSON is the arguments, healing it then dumping it produces the same output as the input (just reformatted).
+ test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python"); + for (const auto & src : barely_healable_jsons) { + test(src, true, {{}}, {}, src); + } + + // Full JSON w/ args + for (auto parse_as_partial : {true, false}) { + test_with_args( + R"({"name": "python", "args": {"arg1": 1}})", + R"({"name":"python","args":"{\"arg1\":1}"})", + parse_as_partial, + /* is_partial= */ false + ); + } + + // Partial JSON w/ partial args + test_with_args( + R"({"foo": "bar", "args": {")", + R"({"foo":"bar","args":"{\""})" + ); + // Partial args broken in object key + test_with_args( + R"({"foo": "bar", "args": {"ar)", + R"({"foo":"bar","args":"{\"ar"})" + ); + // Partial args broken after object key + test_with_args( + R"({"foo": "bar", "args": {"arg1")", + R"({"foo":"bar","args":"{\"arg1\""})" + ); + // Partial args broken before object value + test_with_args( + R"({"foo": "bar", "args": {"arg1":)", + R"({"foo":"bar","args":"{\"arg1\":"})" + ); + // Partial args broken before object value (space) + test_with_args( + R"({"foo": "bar", "args": {"arg1": )", + R"({"foo":"bar","args":"{\"arg1\":"})" + ); + // Partial args broken in object value that may not be complete (int) + test_with_args( + R"({"foo": "bar", "args": {"arg1": 1)", + R"({"foo":"bar","args":"{\"arg1\":"})" + ); + // Partial args broken in object value that is complete (int) + test_with_args( + R"({"foo": "bar", "args": {"arg1": 1 )", + R"({"foo":"bar","args":"{\"arg1\":1"})" + ); + // Partial args broken in object value that is incomplete (string) + test_with_args( + R"({"foo": "bar", "args": {"arg1": ")", + R"({"foo":"bar","args":"{\"arg1\":\""})" + ); + // Partial args broken in object value that is complete (string) + test_with_args( + R"({"foo": "bar", "args": {"arg1": "1")", + R"({"foo":"bar","args":"{\"arg1\":\"1\""})" + ); + // Partial args broken on array opening + test_with_args( + R"({"foo": "bar", "args": [)", + R"({"foo":"bar","args":"["})" + ); + // Partial args broken on array value that is incomplete (int) + test_with_args( + R"({"foo": "bar", "args": [1)", + R"({"foo":"bar","args":"["})" + ); + // Partial args broken on array value that is complete (int) + test_with_args( + R"({"foo": "bar", "args": [1 )", + R"({"foo":"bar","args":"[1"})" + ); + // Partial args broken on array value that is complete (string) + test_with_args( + R"({"foo": "bar", "args": ["1")", + R"({"foo":"bar","args":"[\"1\""})" + ); + // Partial args broken after array value + test_with_args( + R"({"foo": "bar", "args": [1,)", + R"({"foo":"bar","args":"[1,"})" + ); + // Partial args broken on nested array + test_with_args( + R"({"foo": "bar", "args": {"arg1": [)", + R"({"foo":"bar","args":"{\"arg1\":["})" + ); +} + +static void test_positions() { + { + common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {}); + assert_equals(0, builder.pos()); + assert_throws([&]() { builder.move_to(100); }); + assert_equals(0, builder.pos()); + assert_throws([&]() { builder.move_back(1); }); + assert_equals(0, builder.pos()); + + builder.move_to(8); + assert_equals(8, builder.pos()); + builder.move_back(1); + assert_equals(7, builder.pos()); + assert_equals("world!", builder.consume_rest()); + + builder.move_to(0); + assert_equals(0, builder.pos()); + + assert_throws([&]() { builder.finish(); }); + assert_equals(0, builder.pos()); + + builder.move_to(builder.input().size()); + builder.finish(); + } + { + common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {}); + + builder.move_to(builder.input().size()); + 
assert_equals(builder.input().size(), builder.pos()); + builder.finish(); + } +} + +int main() { + test_positions(); + test_json_with_dumped_args_no_args(); + test_json_with_dumped_args(); + test_reasoning(); + test_regex(); + std::cout << "All tests passed!\n"; + return 0; +} diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 4d70da8c3..1c9807921 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -5,21 +5,80 @@ // // cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null // -#include -#include -#include -#include - #include "chat.h" #include "../src/unicode.h" #include "../src/llama-grammar.h" +#include + +#include +#include +#include + using json = nlohmann::ordered_json; +static std::ostream & operator<<(std::ostream & os, const common_chat_msg_diff & diff) { + // os << "reasoning_content_delta: " << diff.reasoning_content_delta << '\n'; + os << "{ content_delta: " << diff.content_delta << "; "; + if (diff.tool_call_index != std::string::npos) { + os << "tool_call_index: " << diff.tool_call_index << "; "; + os << "tool_call_delta.name: " << diff.tool_call_delta.name << "; "; + os << "tool_call_delta.id: " << diff.tool_call_delta.id << "; "; + os << "tool_call_delta.arguments: " << diff.tool_call_delta.arguments << "; "; + } + os << "}"; + return os; +} +// operator<< for vector: +static std::ostream & operator<<(std::ostream & os, const std::vector & diffs) { + os << "[\n"; + for (const auto & diff : diffs) { + os << " " << diff << ",\n"; + } + os << "]"; + return os; +} +static std::ostream & operator<<(std::ostream & os, const common_chat_msg & msg) { + os << "{ role: " << msg.role << "; "; + os << "content: " << msg.content << "; "; + os << "content_parts: [\n"; + for (const auto & part : msg.content_parts) { + os << " { type: " << part.type << "; text: " << part.text << " },\n"; + } + os << "]; "; + os << "reasoning_content: " << msg.reasoning_content << "; "; + os << "tool_calls: [\n"; + for (const auto & tool_call : msg.tool_calls) { + os << " { name: " << tool_call.name << "; arguments: " << tool_call.arguments << "; id: " << tool_call.id << " },\n"; + } + os << "]"; + os << "}"; + return os; +} + +template static bool equals(const T & expected, const T & actual) { + return expected == actual; +} + +static common_chat_msg normalize(const common_chat_msg & msg) { + common_chat_msg normalized = msg; + for (auto & tool_call : normalized.tool_calls) { + try { + tool_call.arguments = json::parse(tool_call.arguments).dump(); + } catch (const std::exception &) { + // Do nothing + } + } + return normalized; +} +template <> +bool equals(const common_chat_msg & expected, const common_chat_msg & actual) { + return normalize(expected) == normalize(actual); +} template static void assert_equals(const T & expected, const T & actual) { - if (expected != actual) { + if (!equals(expected, actual)) { std::cerr << "Expected: " << expected << std::endl; std::cerr << "Actual: " << actual << std::endl; std::cerr << std::flush; @@ -77,6 +136,15 @@ static bool match_string(const std::string & input, llama_grammar * grammar) { return false; } +static std::string renormalize_json(const std::string & json_str) { + try { + auto json_obj = json::parse(json_str); + return json_obj.dump(); + } catch (const std::exception & e) { + std::cerr << "Failed to parse JSON: " << e.what() << '\n'; + return json_str; + } +} static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) { 
assert_equals(expected.role, actual.role); assert_equals(expected.content, actual.content); @@ -93,7 +161,7 @@ static void assert_msg_equals(const common_chat_msg & expected, const common_cha const auto & expected_tool_call = expected.tool_calls[i]; const auto & actual_tool_call = actual.tool_calls[i]; assert_equals(expected_tool_call.name, actual_tool_call.name); - assert_equals(json::parse(expected_tool_call.arguments).dump(), json::parse(actual_tool_call.arguments).dump()); + assert_equals(renormalize_json(expected_tool_call.arguments), renormalize_json(actual_tool_call.arguments)); assert_equals(expected_tool_call.id, actual_tool_call.id); } } @@ -152,14 +220,12 @@ static delta_data init_delta(const struct common_chat_templates * tmpls, const s const common_chat_msg & user_message, const common_chat_msg & delta_message, const std::vector & tools, - const common_chat_tool_choice & tool_choice, - bool think = false) { + const common_chat_tool_choice & tool_choice) { common_chat_templates_inputs inputs; inputs.parallel_tool_calls = true; inputs.messages.push_back(user_message); inputs.tools = tools; inputs.tool_choice = tool_choice; - inputs.extract_reasoning = think; auto params_prefix = common_chat_templates_apply(tmpls, inputs); inputs.messages.push_back(delta_message); @@ -211,19 +277,22 @@ static void test_templates(const struct common_chat_templates * tmpls, const std const std::string & expected_delta = "", bool expect_grammar_triggered = true, bool test_grammar_if_triggered = true, - bool think = false) { + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE) { common_chat_msg user_message; user_message.role = "user"; user_message.content = "Hello, world!"; for (const auto & tool_choice : std::vector {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) { - auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think); + auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice); if (!expected_delta.empty()) { assert_equals(expected_delta, data.delta); } if (expect_grammar_triggered) { - const auto msg = common_chat_parse(data.delta, data.params.format); + common_chat_syntax syntax; + syntax.format = data.params.format; + syntax.reasoning_format = reasoning_format; + const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, syntax); assert_msg_equals(test_message, msg); } @@ -251,15 +320,25 @@ static void test_templates(const struct common_chat_templates * tmpls, const std { const auto & pattern = trigger.value; if (std::regex_search(constrained, match, std::regex(pattern))) { - pos = match.position(); + pos = match.position(1); } break; } - case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START: + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: { const auto & pattern = trigger.value; - if (std::regex_search(constrained, match, std::regex(pattern)) && match.position() == 0) { - pos = 0; + if (std::regex_match(constrained, match, std::regex(pattern))) { + auto mpos = std::string::npos; + for (size_t i = 1; i < match.size(); ++i) { + if (match[i].length() > 0) { + mpos = match.position(i); + break; + } + } + if (mpos == std::string::npos) { + mpos = match.position(0); + } + pos = mpos; } break; } @@ -313,117 +392,42 @@ const common_chat_msg message_user_parts { /* .tool_name = */ "", /* .tool_call_id = */ "", }; -const common_chat_msg message_assist { - "assistant", - "Hello, world!\nWhat's up?", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ 
"", - /* .tool_name = */ "", - /* .tool_call_id = */ "", -}; -const common_chat_msg message_assist_thoughts_unparsed_think { - "assistant", - "I'm thinkingHello, world!\nWhat's up?", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", -}; -const common_chat_msg message_assist_thoughts_unparsed_r7b { - "assistant", - "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", -}; -const common_chat_msg message_assist_thoughts { - "assistant", - "Hello, world!\nWhat's up?", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "I'm thinking", - /* .tool_name = */ "", - /* .tool_call_id = */ "", -}; -const std::vector tool_calls { - { "special_function", "{\"arg1\": 1}", /* .id = */ "" }, -}; -const std::vector tool_calls_idx { - { "special_function", "{\"arg1\": 1}", /* .id = */ "0" }, -}; -const std::vector tool_calls_id { - { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" }, -}; +static common_chat_msg simple_assist_msg(const std::string & content, const std::string & reasoning_content = "", const std::string & tool_name = "", const std::string & arguments = "", const std::string & id = "") { + common_chat_msg msg; + msg.role = "assistant"; + msg.content = content; + msg.reasoning_content = reasoning_content; + if (!tool_name.empty()) { + msg.tool_calls.push_back({ tool_name, arguments, id }); + } + return msg; +} +const common_chat_msg message_assist = simple_assist_msg("Hello, world!\nWhat's up?"); +const common_chat_msg message_assist_empty = simple_assist_msg(""); +const common_chat_msg message_assist_thoughts_unparsed_deepseek = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?"); +const common_chat_msg message_assist_thoughts_unparsed_md = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```"); +const common_chat_msg message_assist_thoughts_unparsed_md_partial = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}"); -const common_chat_msg message_assist_call { - "assistant", - "", - /* .content_parts = */ {}, - tool_calls, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", -}; -const common_chat_msg message_assist_call_thoughts = { - "assistant", - /* .content = */ "", - /* .content_parts = */ {}, - tool_calls, - /* .reasoning_content = */ "I'm\nthinking", - /* .tool_name = */ "", - /* .tool_call_id = */ "", -}; -const common_chat_msg message_assist_call_thoughts_unparsed = { - "assistant", - /* .content = */ "I'm\nthinking", - /* .content_parts = */ {}, - tool_calls, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", -}; -const common_chat_msg message_assist_call_id { - "assistant", - "", - /* .content_parts = */ {}, - tool_calls_id, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", -}; -const common_chat_msg message_assist_call_idx { - "assistant", - "", - /* .content_parts = */ {}, - tool_calls_idx, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", -}; -const common_chat_msg message_assist_call_python { - "assistant", - "", - /* .content_parts = */ {}, - { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", -}; 
-const common_chat_msg message_assist_call_code_interpreter { - "assistant", - "", - /* .content_parts = */ {}, - { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", -}; +const common_chat_msg message_assist_thoughts_unparsed_r7b = simple_assist_msg("<|START_THINKING|>I'm\nthinking<|END_THINKING|>Hello, world!\nWhat's up?"); +const common_chat_msg message_assist_thoughts = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"); +const common_chat_msg message_assist_thoughts_unopened_unparsed = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?"); +const common_chat_msg message_assist_thoughts_no_content = simple_assist_msg("", "I'm\nthinking"); +const common_chat_msg message_assist_call = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"); +const common_chat_msg message_assist_call_content = simple_assist_msg("Hello, world!\nWhat's up?", "", "special_function", "{\"arg1\":1}"); +const common_chat_msg message_assist_call_empty_args = simple_assist_msg("", "", "special_function"); +const common_chat_msg message_assist_call_cutoff_args = simple_assist_msg("", "", "special_function", "{\"arg"); +const common_chat_msg message_assist_call_thoughts = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\":1}"); +const common_chat_msg message_assist_call_thoughts_unparsed = simple_assist_msg("I'm\nthinking\n\n", "", "special_function", "{\"arg1\": 1}"); +const common_chat_msg message_assist_call_id = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "123456789"); +const common_chat_msg message_assist_call_idx = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "0"); +const common_chat_msg message_assist_thoughts_call_idx = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}", /* id = */ "0"); +const common_chat_msg message_assist_call_python = simple_assist_msg("", "", "python", "{\"code\":\"print('hey')\"}"); +const common_chat_msg message_assist_call_python_lines = simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')\"}"); +const common_chat_msg message_assist_call_python_lines_unclosed = simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')"); +const common_chat_msg message_assist_call_code_interpreter = simple_assist_msg("", "", "code_interpreter", "{\"code\":\"print('hey')\"}"); static void test_msgs_oaicompat_json_conversion() { + printf("[%s]\n", __func__); std::vector msgs{ message_user, message_user_parts, @@ -473,7 +477,7 @@ static void test_msgs_oaicompat_json_conversion() { " \"type\": \"function\",\n" " \"function\": {\n" " \"name\": \"python\",\n" - " \"arguments\": \"{\\\"code\\\": \\\"print('hey')\\\"}\"\n" + " \"arguments\": \"{\\\"code\\\":\\\"print('hey')\\\"}\"\n" " }\n" " }\n" " ]\n" @@ -499,6 +503,7 @@ static void test_msgs_oaicompat_json_conversion() { } static void test_tools_oaicompat_json_conversion() { + printf("[%s]\n", __func__); std::vector tools{ special_function_tool, python_tool, @@ -543,29 +548,18 @@ static void test_tools_oaicompat_json_conversion() { } static void test_template_output_parsers() { + printf("[%s]\n", __func__); common_chat_templates_inputs inputs_no_tools; inputs_no_tools.messages = {message_user}; - inputs_no_tools.extract_reasoning = false; - - common_chat_templates_inputs inputs_no_tools_think; - inputs_no_tools_think.messages = {message_user}; - 
inputs_no_tools_think.extract_reasoning = true; common_chat_templates_inputs inputs_tools; inputs_tools.messages = {message_user}; inputs_tools.tools = {special_function_tool}; - inputs_tools.extract_reasoning = false; - - common_chat_templates_inputs inputs_tools_think; - inputs_tools_think.messages = {message_user}; - inputs_tools_think.tools = {special_function_tool}; - inputs_tools_think.extract_reasoning = true; common_chat_templates_inputs inputs_tools_builtin; inputs_tools_builtin.messages = {message_user}; inputs_tools_builtin.tools = {python_tool}; - inputs_tools_builtin.extract_reasoning = false; { // Not supported yet @@ -577,44 +571,87 @@ static void test_template_output_parsers() { auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"); std::vector end_tokens{ "<|END_OF_TURN_TOKEN|>" }; - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format); + for (const auto & inputs : { inputs_no_tools, inputs_tools }) { + auto params = common_chat_templates_apply(tmpls.get(), inputs); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, params.format); + assert_equals(false, params.thinking_forced_open); + } assert_msg_equals(message_assist, common_chat_parse( "Hello, world!\nWhat's up?", - COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(message_assist, - common_chat_parse( - "Hello, world!\nWhat's up?<|END_RESPONSE|>", - COMMON_CHAT_FORMAT_COMMAND_R7B)); + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_COMMAND_R7B})); assert_msg_equals(message_assist, common_chat_parse( "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", - COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(message_assist_thoughts_unparsed_r7b, - common_chat_parse( - "<|START_THINKING|>I'm thinking<|END_THINKING|>" - "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", - COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(message_assist_thoughts_unparsed_r7b, - common_chat_parse( - "<|START_THINKING|>I'm thinking<|END_THINKING|>" - "Hello, world!\nWhat's up?<|END_RESPONSE|>", - COMMON_CHAT_FORMAT_COMMAND_R7B)); - + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_COMMAND_R7B})); assert_msg_equals(message_assist_thoughts, common_chat_parse( - "<|START_THINKING|>I'm thinking<|END_THINKING|>" + "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", - COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING)); + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + assert_msg_equals(message_assist_thoughts_unparsed_deepseek, + common_chat_parse( + "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" + "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ true, + /* .thinking_forced_open = */ false, + })); + assert_msg_equals(message_assist_thoughts_unparsed_r7b, + common_chat_parse( + "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" + "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", + /* is_partial= */ false, + 
{COMMON_CHAT_FORMAT_COMMAND_R7B})); + assert_msg_equals(message_assist_thoughts, + common_chat_parse( + "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" + "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + assert_msg_equals(message_assist_thoughts_call_idx, + common_chat_parse( + "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" + "<|START_ACTION|>[\n" + " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" + "]<|END_ACTION|>", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + assert_msg_equals(message_assist_thoughts_no_content, + common_chat_parse( + "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" + "<|START_ACTION|>[\n" + " {\"tool_call_id\": \"0\", \"tool_name\": \"special", + /* is_partial= */ true, + { + /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools, "<|START_THINKING|><|END_THINKING|>" "<|START_ACTION|>[\n" " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" - "]<|END_ACTION|>"); + "]<|END_ACTION|>", + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ true, + COMMON_REASONING_FORMAT_DEEPSEEK); test_templates(tmpls.get(), end_tokens, message_assist, tools, "<|START_RESPONSE|>Hello, world!\n" "What's up?<|END_RESPONSE|>", @@ -634,11 +671,52 @@ static void test_template_output_parsers() { // Generic tool calls doesn't generate / parse content-only messages symmetrically. 
+ assert_equals( + simple_assist_msg("{ \"tool_call\" : { \"name\" : \"t"), + common_chat_parse( + "{ \"tool_call\" : { \"name\" : \"t", + /* is_partial= */ true, + { + /* .format = */ COMMON_CHAT_FORMAT_GENERIC, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ false, + })); + assert_equals( + message_assist_empty, + common_chat_parse( + "{ \"tool_call\" : { \"name\" : \"t", + /* is_partial= */ true, + {COMMON_CHAT_FORMAT_GENERIC})); + + assert_equals( + simple_assist_msg("", "", "puppeteer_screenshot", "{\"name\":\"servethehome_homepage\","), + common_chat_parse( + R"({"tool_call": {"name": "puppeteer_screenshot", "arguments": {"name": "servethehome_homepage",)", + /* is_partial= */ true, + {COMMON_CHAT_FORMAT_GENERIC})); + + assert_equals( + message_assist_call_empty_args, + common_chat_parse( + "{ \"tool_call\" : { \"name\" : \"special_function\"", + /* is_partial= */ true, + {COMMON_CHAT_FORMAT_GENERIC})); + assert_equals( + message_assist_call_cutoff_args, + common_chat_parse( + "{ \"tool_call\" : { \"name\" : \"special_function\", \"arguments\" : { \"arg", + /* is_partial= */ true, + {COMMON_CHAT_FORMAT_GENERIC})); + assert_msg_equals(message_assist, - common_chat_parse("{\n" - " \"response\": \"Hello, world!\\nWhat's up?\"\n" - "}", - common_chat_templates_apply(tmpls.get(), inputs_tools).format)); + common_chat_parse( + "{\n" + " \"response\": \"Hello, world!\\nWhat's up?\"\n" + "}", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_GENERIC})); test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools, "{\n" " \"tool_calls\": [\n" @@ -663,11 +741,18 @@ static void test_template_output_parsers() { tmpls.get(), end_tokens, message_assist_call_id, tools, "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]"); } + { + auto tmpls = read_templates("models/templates/Qwen-QwQ-32B.jinja"); + std::vector end_tokens{ "<|im_end|>" }; + + assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + } { auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"); std::vector end_tokens{ "<|im_end|>" }; - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format); assert_equals( COMMON_CHAT_FORMAT_HERMES_2_PRO, @@ -683,114 +768,288 @@ static void test_template_output_parsers() { .format); // Test parsing - assert_msg_equals(message_assist_call, common_chat_parse( - "\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "{\"arg1\": 1}", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "\n" - "{\"arg1\": 1}\n" - "", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, 
common_chat_parse( - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "```xml\n" - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "\n" - "```", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "```xml\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "```", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "```\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "```", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "```\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "```", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "```json\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "```", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "```json\n" - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n" - " \n" - "``` ", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "\n" - " {\n" - " \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n" - " }\n" - "", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_call, common_chat_parse( - "{\n \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals( + simple_assist_msg("", "", "python", ""), + common_chat_parse( + "```json\n" + " { \"name\" : \"python\"", + /* is_partial= */ true, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + simple_assist_msg("Let's call something\n"), + common_chat_parse( + "Let's call something\n" + "{\"name\"", + /* is_partial= */ true, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + assert_msg_equals( + simple_assist_msg("Let's call something\n"), + common_chat_parse( + "Let's call something\n" + "{\"name", + /* is_partial= */ true, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + assert_msg_equals(message_assist_call_thoughts, + common_chat_parse( + // QwQ-32B's template adds a trailing if add_generation_prompt + "I'm\nthinking\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* 
.thinking_forced_open = */ true, + })); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals(message_assist_call_content, + common_chat_parse( + "Hello, world!\nWhat's up?\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "{\"arg1\": 1}", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "\n" + "{\"arg1\": 1}\n" + "", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "```xml\n" + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "\n" + "```", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "```xml\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "```", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "```\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "```", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "```\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "```", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "```json\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "```", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "```json\n" + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n" + " \n" + "``` ", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "\n" + " {\n" + " \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n" + " }\n" + "", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "{\"name\": 
\"special_function\", \"arguments\": {\"arg1\": 1}}", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "{\n \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals(message_assist_thoughts_unparsed_think, - common_chat_parse("I'm thinkingHello, world!\nWhat's up?", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_thoughts_unparsed_think, - common_chat_parse("I'm thinkingHello, world!\nWhat's up?", - COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals( + simple_assist_msg( + "This is not a tool call:", + "", + "special_function", + "{\"arg1\": 1}"), + common_chat_parse( + "This is not a tool call:\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals(message_assist, + common_chat_parse( + "Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals(message_assist_thoughts_unparsed_deepseek, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + // assert_msg_equals(message_assist_thoughts_unparsed_deepseek, + // common_chat_parse( + // "I'm\nthinkingHello, world!\nWhat's up?", + // COMMON_CHAT_FORMAT_HERMES_2_PRO)); assert_msg_equals(message_assist_thoughts, - common_chat_parse("I'm thinkingHello, world!\nWhat's up?", - COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING)); + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); assert_msg_equals(message_assist_thoughts, - common_chat_parse("I'm thinkingHello, world!\nWhat's up?", - COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING)); + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ true, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + assert_msg_equals(message_assist_thoughts_unparsed_md, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ true, + /* .thinking_forced_open = */ false, + /* .parse_tool_calls = */ false, + })); + assert_msg_equals(message_assist_thoughts_unparsed_md_partial, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```", + /* is_partial= */ true, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ true, + /* .thinking_forced_open = */ false, + })); + assert_msg_equals(message_assist_thoughts_unopened_unparsed, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + assert_msg_equals(message_assist_thoughts, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = 
*/ false, + /* .thinking_forced_open = */ true, + })); test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" ""); - test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call_python_lines, tools, "\n" - "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n" + "{\"name\": \"python\", \"arguments\": {\"code\":\"# This is a program:\\nprint('hey')\"}}\n" ""); + assert_msg_equals( + simple_assist_msg("", /* reasoning_content= */ "nah uhg"), + common_chat_parse( + "nah uhg", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); } { auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"); @@ -806,6 +1065,13 @@ static void test_template_output_parsers() { inputs_tools_builtin) .format); + assert_equals( + message_assist_call, + common_chat_parse( + "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LLAMA_3_X})); + // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false); test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools, "<|python_tag|>code_interpreter.call(code=\"print('hey')\")"); @@ -836,6 +1102,22 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + for (auto is_partial : { false, true }) { + assert_equals( + message_assist_call, + common_chat_parse( + "{\"arg1\": 1}", + is_partial, + {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1})); + } + + assert_equals( + message_assist_call, + common_chat_parse( + "{\"arg1\": 1}<", + /* is_partial= */ true, + {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1})); + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "{\"arg1\": 1}"); @@ -847,6 +1129,47 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + assert_msg_equals( + simple_assist_msg( + "Hello, world!\nnono\nWhat's up?", + "", + "special_function", + "{\"arg1\": 1}"), + common_chat_parse( + "all\n" + "Hello, world!\n" + "nono\n" + "What's up?>>>special_function\n" + "{\"arg1\": 1}\n", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); + assert_msg_equals(message_assist_call_python_lines, + common_chat_parse( + "python\n" + "# This is a program:\n" + "print('hey')", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); + assert_msg_equals(message_assist_call_python_lines_unclosed, + common_chat_parse( + "python\n" + "# This is a program:\n" + "print('hey')", + /* is_partial= */ true, + {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); + assert_msg_equals(message_assist_call, + common_chat_parse( + "special_function\n" + "{\"arg1\": 1} \n ", + /* is_partial= */ false, + 
{COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); + assert_msg_equals(message_assist, + common_chat_parse( + "all\n" + "Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); + test_templates(tmpls.get(), end_tokens, message_assist, {}, "all\n" "Hello, world!\n" @@ -872,22 +1195,73 @@ static void test_template_output_parsers() { auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"); std::vector end_tokens{ "<|end▁of▁sentence|>" }; - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format); + for (const auto & inputs : { inputs_no_tools, inputs_tools }) { + auto params = common_chat_templates_apply(tmpls.get(), inputs); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, params.format); + assert_equals(true, params.thinking_forced_open); + } test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - assert_msg_equals(message_assist_thoughts_unparsed_think, - common_chat_parse("I'm thinkingHello, world!\nWhat's up?", - COMMON_CHAT_FORMAT_DEEPSEEK_R1)); + assert_msg_equals( + simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"), + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + COMMON_CHAT_FORMAT_DEEPSEEK_R1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + })); + assert_msg_equals( + simple_assist_msg("", "I need to remember the correct syntax. It starts with <|tool▁calls▁begin|> and ends with"), + common_chat_parse( + "I need to remember the correct syntax. It starts with <|tool▁calls▁begin|> and ends with", + /* is_partial= */ true, + { + COMMON_CHAT_FORMAT_DEEPSEEK_R1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + })); assert_msg_equals(message_assist_thoughts, - common_chat_parse("I'm thinkingHello, world!\nWhat's up?", - COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + assert_msg_equals(message_assist_thoughts_unopened_unparsed, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + assert_msg_equals(message_assist_thoughts, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + })); assert_msg_equals(message_assist_thoughts, // Latest template update (ast of 20250209) adds a trailing \n if add_generation_prompt is true. 
- common_chat_parse("I'm thinkingHello, world!\nWhat's up?", - COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + })); // test_templates(tmpls.get(), end_tokens, message_assist_call, tools, // "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" // "```json\n" @@ -904,16 +1278,32 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format); test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - assert_msg_equals(message_assist_thoughts_unparsed_think, - common_chat_parse("I'm thinkingHello, world!\nWhat's up?", - COMMON_CHAT_FORMAT_DEEPSEEK_R1)); + assert_msg_equals(message_assist_thoughts_unparsed_deepseek, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_DEEPSEEK_R1})); assert_msg_equals(message_assist_thoughts, - common_chat_parse("I'm thinkingHello, world!\nWhat's up?", - COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + assert_msg_equals(message_assist_thoughts, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + })); assert_msg_equals(message_assist_call_thoughts_unparsed, common_chat_parse( @@ -922,7 +1312,17 @@ static void test_template_output_parsers() { "```json\n" "{\"arg1\": 1}\n" "```<|tool▁call▁end|><|tool▁calls▁end|>", - COMMON_CHAT_FORMAT_DEEPSEEK_R1)); + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_DEEPSEEK_R1})); + assert_msg_equals(message_assist_call, + common_chat_parse( + "<|tool▁calls|>function<|tool▁sep|>special_function\n" + "```json\n" + "{\"arg1\": 1}\n" + "```<|tool▁call▁end|><|tool▁calls▁end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_DEEPSEEK_R1})); + assert_msg_equals(message_assist_call_thoughts, common_chat_parse( "I'm\nthinking\n\n" @@ -930,7 +1330,11 @@ static void test_template_output_parsers() { "```json\n" "{\"arg1\": 1}\n" "```<|tool▁call▁end|><|tool▁calls▁end|>", - COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" "```json\n" @@ -939,6 +1343,90 @@ static void 
test_template_output_parsers() { } } +static void test_msg_diffs_compute() { + printf("[%s]\n", __func__); + { + common_chat_msg msg1; + + common_chat_msg msg2; + msg2.content = "Hello, world!"; + + common_chat_msg_diff diff; + diff.content_delta = "Hello, world!"; + + assert_equals( + {diff}, + common_chat_msg_diff::compute_diffs(msg1, msg2)); + } + { + common_chat_msg msg1; + msg1.content = "Hello,"; + + common_chat_msg msg2; + msg2.content = "Hello, world!"; + + common_chat_msg_diff diff; + diff.content_delta = " world!"; + + assert_equals( + {diff}, + common_chat_msg_diff::compute_diffs(msg1, msg2)); + } + { + common_chat_msg msg0; + + common_chat_msg msg1; + msg1.tool_calls = { { "special_function", "{\"ar", /* .id = */ "123" } }; + + common_chat_msg msg2; + msg2.tool_calls = { { "special_function", "{\"arg1\": 1}", /* .id = */ "123" } }; + + common_chat_msg_diff diff01; + diff01.tool_call_index = 0; + diff01.tool_call_delta.name = "special_function"; + diff01.tool_call_delta.id = "123"; + diff01.tool_call_delta.arguments = "{\"ar"; + + assert_equals( + {diff01}, + common_chat_msg_diff::compute_diffs(msg0, msg1)); + + common_chat_msg_diff diff12; + diff12.tool_call_index = 0; + // Note: neither id nor name change here. + diff12.tool_call_delta.arguments = "g1\": 1}"; + + assert_equals( + {diff12}, + common_chat_msg_diff::compute_diffs(msg1, msg2)); + } + { + common_chat_msg msg0; + + common_chat_msg msg2; + msg2.tool_calls = { + { "f1", "{\"arg1\": 1}", /* .id = */ "123" }, + { "f2", "{\"arg2\": 2}", /* .id = */ "222" }, + }; + + common_chat_msg_diff diff1; + diff1.tool_call_index = 0; + diff1.tool_call_delta.name = "f1"; + diff1.tool_call_delta.id = "123"; + diff1.tool_call_delta.arguments = "{\"arg1\": 1}"; + + common_chat_msg_diff diff2; + diff2.tool_call_index = 1; + diff2.tool_call_delta.name = "f2"; + diff2.tool_call_delta.id = "222"; + diff2.tool_call_delta.arguments = "{\"arg2\": 2}"; + + assert_equals( + {diff1, diff2}, + common_chat_msg_diff::compute_diffs(msg0, msg2)); + } +} + int main(int argc, char ** argv) { // try { #ifndef _WIN32 @@ -972,6 +1460,7 @@ int main(int argc, char ** argv) { } else #endif { + test_msg_diffs_compute(); test_msgs_oaicompat_json_conversion(); test_tools_oaicompat_json_conversion(); test_template_output_parsers(); diff --git a/tests/test-gguf.cpp b/tests/test-gguf.cpp index eaf572c66..3f0c312e2 100644 --- a/tests/test-gguf.cpp +++ b/tests/test-gguf.cpp @@ -16,6 +16,7 @@ constexpr int offset_has_data = 3000; enum handcrafted_file_type { HANDCRAFTED_HEADER_BAD_MAGIC = 10, + HANDCRAFTED_HEADER_BAD_VERSION_0 = 15, HANDCRAFTED_HEADER_BAD_VERSION_1 = 20, HANDCRAFTED_HEADER_BAD_VERSION_FUTURE = 30, HANDCRAFTED_HEADER_BAD_N_TENSORS = 40, @@ -51,6 +52,7 @@ enum handcrafted_file_type { static std::string handcrafted_file_type_name(const enum handcrafted_file_type hft) { switch (hft) { case HANDCRAFTED_HEADER_BAD_MAGIC: return "HEADER_BAD_MAGIC"; + case HANDCRAFTED_HEADER_BAD_VERSION_0: return "HEADER_BAD_VERSION_0"; case HANDCRAFTED_HEADER_BAD_VERSION_1: return "HEADER_BAD_VERSION_1"; case HANDCRAFTED_HEADER_BAD_VERSION_FUTURE: return "HEADER_BAD_VERSION_FUTURE"; case HANDCRAFTED_HEADER_BAD_N_KV: return "HEADER_BAD_N_KV"; @@ -171,7 +173,10 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft helper_write(file, GGUF_MAGIC, 4); } - if (hft == HANDCRAFTED_HEADER_BAD_VERSION_1) { + if (hft == HANDCRAFTED_HEADER_BAD_VERSION_0) { + const uint32_t version = 0; + helper_write(file, version); + } else if (hft == 
HANDCRAFTED_HEADER_BAD_VERSION_1) { const uint32_t version = 1; helper_write(file, version); } else if (hft == HANDCRAFTED_HEADER_BAD_VERSION_FUTURE) { @@ -660,6 +665,7 @@ static std::pair test_handcrafted_file(const unsigned int seed) { const std::vector hfts = { HANDCRAFTED_HEADER_BAD_MAGIC, + HANDCRAFTED_HEADER_BAD_VERSION_0, HANDCRAFTED_HEADER_BAD_VERSION_1, HANDCRAFTED_HEADER_BAD_VERSION_FUTURE, HANDCRAFTED_HEADER_BAD_N_KV, diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 8988c347e..6d64f0737 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -7,6 +7,8 @@ #include "../src/unicode.h" #include "../src/llama-grammar.h" +#include + #include #include #include diff --git a/tests/test-json-partial.cpp b/tests/test-json-partial.cpp new file mode 100644 index 000000000..bc136bece --- /dev/null +++ b/tests/test-json-partial.cpp @@ -0,0 +1,237 @@ +#include "common.h" +#include "json-partial.h" +#include +#include +#include + +template static void assert_equals(const T & expected, const T & actual) { + if (expected != actual) { + std::cerr << "Expected: " << expected << std::endl; + std::cerr << "Actual: " << actual << std::endl; + std::cerr << std::flush; + throw std::runtime_error("Test failed"); + } +} + +static void test_json_healing() { + auto parse = [](const std::string & str) { + std::cerr << "# Parsing: " << str << '\n'; + std::string::const_iterator it = str.begin(); + const auto end = str.end(); + common_json out; + std::string healing_marker = "$llama.cpp.json$"; + if (common_json_parse(it, end, healing_marker, out)) { + auto dump = out.json.dump(); + std::cerr << "Parsed: " << dump << '\n'; + std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n'; + std::string result; + if (!out.healing_marker.json_dump_marker.empty()) { + auto i = dump.find(out.healing_marker.json_dump_marker); + if (i == std::string::npos) { + throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")"); + } + result = dump.substr(0, i); + } else { + result = dump; + } + std::cerr << "Result: " << result << '\n'; + if (string_starts_with(str, result)) { + std::cerr << "Failure!\n"; + } + // return dump; + } else { + throw std::runtime_error("Failed to parse: " + str); + } + + }; + auto parse_all = [&](const std::string & str) { + for (size_t i = 1; i < str.size(); i++) { + parse(str.substr(0, i)); + } + }; + parse_all("{\"a\": \"b\"}"); + parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}"); + + parse_all("[{\"a\": \"b\"}]"); + + auto test = [&](const std::vector & inputs, const std::string & expected, const std::string & expected_marker) { + for (const auto & input : inputs) { + common_json out; + assert_equals(true, common_json_parse(input, "$foo", out)); + assert_equals(expected, out.json.dump()); + assert_equals(expected_marker, out.healing_marker.json_dump_marker); + } + }; + // No healing needed: + test( + { + R"([{"a":"b"}, "y"])", + }, + R"([{"a":"b"},"y"])", + "" + ); + // Partial literals can't be healed: + test( + { + R"([1)", + R"([tru)", + R"([n)", + R"([nul)", + R"([23.2)", + }, + R"(["$foo"])", + R"("$foo)" + ); + test( + { + R"({"a": 1)", + R"({"a": tru)", + R"({"a": n)", + R"({"a": nul)", + R"({"a": 23.2)", + }, + R"({"a":"$foo"})", + R"("$foo)" + ); + test( + { + R"({)", + }, + R"({"$foo":1})", + R"("$foo)" + ); + test( + { + R"([)", + }, + R"(["$foo"])", + R"("$foo)" + ); + // Healing right after a full literal + test( + { + R"(1 )", + }, + 
R"(1)", + "" + ); + test( + { + R"(true)", + R"(true )", + }, + R"(true)", + "" + ); + test( + { + R"(null)", + R"(null )", + }, + R"(null)", + "" + ); + test( + { + R"([1 )", + }, + R"([1,"$foo"])", + R"(,"$foo)" + ); + test( + { + R"([{})", + R"([{} )", + }, + R"([{},"$foo"])", + R"(,"$foo)" + ); + test( + { + R"([true)", + }, + // TODO: detect the true/false/null literal was complete + R"(["$foo"])", + R"("$foo)" + ); + test( + { + R"([true )", + }, + R"([true,"$foo"])", + R"(,"$foo)" + ); + test( + { + R"([true,)", + }, + R"([true,"$foo"])", + R"("$foo)" + ); + // Test nesting + test( + { + R"([{"a": [{"b": [{)", + }, + R"([{"a":[{"b":[{"$foo":1}]}]}])", + R"("$foo)" + ); + test( + { + R"([{"a": [{"b": [)", + }, + R"([{"a":[{"b":["$foo"]}]}])", + R"("$foo)" + ); + + test( + { + R"([{"a": "b"})", + R"([{"a": "b"} )", + }, + R"([{"a":"b"},"$foo"])", + R"(,"$foo)" + ); + test( + { + R"([{"a": "b"},)", + R"([{"a": "b"}, )", + }, + R"([{"a":"b"},"$foo"])", + R"("$foo)" + ); + test( + { + R"({ "code)", + }, + R"({"code$foo":1})", + R"($foo)" + ); + test( + { + R"({ "code\)", + }, + R"({"code\\$foo":1})", + R"(\$foo)" + ); + test( + { + R"({ "code")", + }, + R"({"code":"$foo"})", + R"(:"$foo)" + ); + test( + { + R"({ "key")", + }, + R"({"key":"$foo"})", + R"(:"$foo)" + ); +} + +int main() { + test_json_healing(); + std::cerr << "All tests passed.\n"; + return 0; +} diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 38cf01d6d..78ee55e24 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -6,6 +6,8 @@ #include "../src/llama-grammar.h" +#include + #include #include #include diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 60ac62b38..6300f25ca 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -98,7 +98,7 @@ static void test_top_p(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector numa mode (default: disabled)\n"); printf(" -r, --repetitions number of times to repeat each test (default: %d)\n", cmd_params_defaults.reps); - printf(" --prio <0|1|2|3> process/thread priority (default: %d)\n", + printf(" --prio <-1|0|1|2|3> process/thread priority (default: %d)\n", cmd_params_defaults.prio); printf(" --delay <0...N> (seconds) delay between each test (default: %d)\n", cmd_params_defaults.delay); diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index e7ba23587..4baa15b96 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -1,37 +1,50 @@ # mtmd -add_library(mtmd OBJECT +find_package(Threads REQUIRED) + +add_library(mtmd mtmd.cpp - mtmd-helper.cpp + mtmd-audio.cpp mtmd.h clip.cpp clip.h clip-impl.h + mtmd-helper.cpp + mtmd-helper.h ) -target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT}) - -target_include_directories(mtmd PUBLIC .) +target_link_libraries (mtmd PUBLIC ggml llama) +target_link_libraries (mtmd PRIVATE Threads::Threads) +target_include_directories(mtmd PUBLIC .) target_include_directories(mtmd PRIVATE ../..) 
-target_include_directories(mtmd PRIVATE ../../common) # for stb_image.h +target_include_directories(mtmd PRIVATE ../../vendor) +target_compile_features (mtmd PRIVATE cxx_std_17) -target_compile_features(mtmd PRIVATE cxx_std_17) - -add_library(mtmd_static STATIC $) if (BUILD_SHARED_LIBS) - set_target_properties(mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON) - target_compile_definitions(mtmd PRIVATE LLAMA_SHARED LLAMA_BUILD) - add_library(mtmd_shared SHARED $) - target_link_libraries(mtmd_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT}) - install(TARGETS mtmd_shared LIBRARY) + set_target_properties (mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(mtmd PRIVATE LLAMA_BUILD) + target_compile_definitions(mtmd PUBLIC LLAMA_SHARED) endif() +set(MTMD_PUBLIC_HEADERS + ${CMAKE_CURRENT_SOURCE_DIR}/mtmd.h + ${CMAKE_CURRENT_SOURCE_DIR}/mtmd-helper.h + ) + +set_target_properties(mtmd + PROPERTIES + PUBLIC_HEADER "${MTMD_PUBLIC_HEADERS}") + +install(TARGETS mtmd LIBRARY PUBLIC_HEADER) + if (NOT MSVC) - target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h + # for stb_image.h and miniaudio.h + target_compile_options(mtmd PRIVATE -Wno-cast-qual) endif() -if(TARGET BUILD_INFO) - add_dependencies(mtmd BUILD_INFO) +if (TARGET BUILD_INFO) + add_dependencies(mtmd BUILD_INFO) + add_dependencies(mtmd-helper BUILD_INFO) endif() add_executable(llama-llava-cli deprecation-warning.cpp) @@ -40,8 +53,8 @@ add_executable(llama-minicpmv-cli deprecation-warning.cpp) add_executable(llama-qwen2vl-cli deprecation-warning.cpp) set(TARGET llama-mtmd-cli) -add_executable(${TARGET} mtmd-cli.cpp) -set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT}) +add_executable (${TARGET} mtmd-cli.cpp) +set_target_properties (${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli) +install (TARGETS ${TARGET} RUNTIME) +target_link_libraries (${TARGET} PRIVATE common mtmd Threads::Threads) target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 7b7d2df39..62c936ed0 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -16,22 +16,26 @@ #define KEY_FTYPE "general.file_type" #define KEY_NAME "general.name" #define KEY_DESCRIPTION "general.description" -#define KEY_MINICPMV_VERSION "clip.minicpmv_version" +#define KEY_PROJ_TYPE "clip.projector_type" +#define KEY_HAS_AUDIO_ENC "clip.has_audio_encoder" +#define KEY_HAS_VISION_ENC "clip.has_vision_encoder" #define KEY_USE_GELU "clip.use_gelu" #define KEY_USE_SILU "clip.use_silu" -#define KEY_N_EMBD "clip.vision.embedding_length" -#define KEY_N_FF "clip.vision.feed_forward_length" -#define KEY_N_BLOCK "clip.vision.block_count" -#define KEY_N_HEAD "clip.vision.attention.head_count" -#define KEY_LAYER_NORM_EPS "clip.vision.attention.layer_norm_epsilon" -#define KEY_PROJ_DIM "clip.vision.projection_dim" + +#define KEY_N_EMBD "clip.%s.embedding_length" +#define KEY_N_FF "clip.%s.feed_forward_length" +#define KEY_N_BLOCK "clip.%s.block_count" +#define KEY_PROJ_DIM "clip.%s.projection_dim" +#define KEY_N_HEAD "clip.%s.attention.head_count" +#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" + +// vision-specific #define KEY_IMAGE_SIZE "clip.vision.image_size" #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" #define KEY_FEATURE_LAYER 
"clip.vision.feature_layer" #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor" -#define KEY_PROJ_TYPE "clip.projector_type" #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" @@ -39,13 +43,18 @@ #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" #define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" #define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" +#define KEY_MINICPMV_VERSION "clip.minicpmv_version" + +// audio-specific +#define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins" +#define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor" // // tensor name constants // -#define TN_POS_EMBD "v.position_embd.weight" +#define TN_POS_EMBD "%s.position_embd.weight" #define TN_CLASS_EMBD "v.class_embd" #define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat #define TN_PATCH_EMBD_1 "v.patch_embd.weight.1" @@ -95,6 +104,13 @@ #define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s" #define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s" +// ultravox +#define TN_CONV1D "a.conv1d.%d.%s" +#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s" +#define TN_MM_AUDIO_FC "mm.a.fc.%s" // fully connected layer +#define TN_MM_NORM_PRE "mm.a.norm_pre.%s" +#define TN_MM_NORM_MID "mm.a.norm_mid.%s" + // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) @@ -110,8 +126,11 @@ enum projector_type { PROJECTOR_TYPE_IDEFICS3, PROJECTOR_TYPE_PIXTRAL, PROJECTOR_TYPE_QWEN25VL, + PROJECTOR_TYPE_ULTRAVOX, PROJECTOR_TYPE_INTERNVL, PROJECTOR_TYPE_LLAMA4, + PROJECTOR_TYPE_QWEN2A, + PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx PROJECTOR_TYPE_UNKNOWN, }; @@ -126,8 +145,11 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_GEMMA3, "gemma3"}, { PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"}, + { PROJECTOR_TYPE_ULTRAVOX, "ultravox"}, { PROJECTOR_TYPE_INTERNVL, "internvl"}, { PROJECTOR_TYPE_LLAMA4, "llama4"}, + { PROJECTOR_TYPE_QWEN2A, "qwen2a"}, + { PROJECTOR_TYPE_QWEN25O, "qwen2.5o"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { @@ -147,8 +169,10 @@ struct clip_image_u8 { std::vector buf; }; -// RGB float32 image (NHWC) -// Memory layout: RGBRGBRGB... +// For images, buf.size() == nx*ny*3 +// Memory layout: RGBRGBRGB... 
+// For audio, only one channel is used, buf.size() == nx*ny +// nx will be n_frames and ny will be n_mel struct clip_image_f32 { int nx; int ny; @@ -242,6 +266,7 @@ struct clip_image_u8_batch { struct clip_image_f32_batch { std::vector entries; + bool is_audio = false; // for llava-uhd style models, we need to know the grid size // note: entries.size() == grid_x * grid_y + 1 (one overview image) @@ -249,7 +274,12 @@ struct clip_image_f32_batch { int grid_y = 0; clip_image_f32_batch clone() const { - clip_image_f32_batch new_batch; + clip_image_f32_batch new_batch{ + /* entries */ {}, + /* is_audio */ is_audio, + /* grid_x */ grid_x, + /* grid_y */ grid_y, + }; new_batch.entries.reserve(entries.size()); for (const auto & entry : entries) { new_batch.entries.emplace_back(new clip_image_f32(*entry)); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index eba07f6c8..c25bacc17 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -11,9 +11,6 @@ #include "ggml-backend.h" #include "gguf.h" -#define STB_IMAGE_IMPLEMENTATION -#include "stb_image.h" - #include #include #include @@ -35,6 +32,7 @@ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callbac enum ffn_op_type { FFN_GELU, + FFN_GELU_ERF, FFN_SILU, FFN_GELU_QUICK, }; @@ -174,9 +172,13 @@ struct clip_hparams { int32_t n_layer; int32_t proj_scale_factor = 0; // idefics3 + float image_mean[3]; + float image_std[3]; + // for models using dynamic image size, we need to have a smaller image size to warmup // otherwise, user will get OOM everytime they load the model int32_t warmup_image_size = 0; + int32_t warmup_audio_size = 3000; ffn_op_type ffn_op = FFN_GELU; @@ -191,6 +193,14 @@ struct clip_hparams { int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; int32_t spatial_merge_size = 0; + + // audio + int32_t n_mel_bins = 0; // whisper preprocessor + int32_t proj_stack_factor = 0; // ultravox + + // legacy + bool has_llava_projector = false; + int minicpmv_version = 0; }; struct clip_layer { @@ -228,8 +238,10 @@ struct clip_layer { ggml_tensor * ls_2_w = nullptr; }; -struct clip_vision_model { - struct clip_hparams hparams; +struct clip_model { + clip_modality modality = CLIP_MODALITY_VISION; + projector_type proj_type = PROJECTOR_TYPE_MLP; + clip_hparams hparams; // embeddings ggml_tensor * class_embedding = nullptr; @@ -246,7 +258,9 @@ struct clip_vision_model { ggml_tensor * post_ln_w; ggml_tensor * post_ln_b; - ggml_tensor * projection; + ggml_tensor * projection; // TODO: rename it to fc (fully connected layer) + ggml_tensor * mm_fc_w; + ggml_tensor * mm_fc_b; // LLaVA projection ggml_tensor * mm_input_norm_w = nullptr; @@ -332,17 +346,18 @@ struct clip_vision_model { // pixtral ggml_tensor * token_embd_img_break = nullptr; ggml_tensor * mm_patch_merger_w = nullptr; + + // ultravox / whisper encoder + ggml_tensor * conv1d_1_w = nullptr; + ggml_tensor * conv1d_1_b = nullptr; + ggml_tensor * conv1d_2_w = nullptr; + ggml_tensor * conv1d_2_b = nullptr; + ggml_tensor * mm_norm_pre_w = nullptr; + ggml_tensor * mm_norm_mid_w = nullptr; }; struct clip_ctx { - bool has_llava_projector = false; - int minicpmv_version = 0; - - struct clip_vision_model vision_model; - projector_type proj_type = PROJECTOR_TYPE_MLP; - - float image_mean[3]; - float image_std[3]; + clip_model model; gguf_context_ptr ctx_gguf; ggml_context_ptr ctx_data; @@ -396,11 +411,16 @@ struct clip_ctx { ggml_backend_free(backend_cpu); } } + + // this function is added so that we don't change too much of the existing code + projector_type 
proj_type() const { + return model.proj_type; + } }; struct clip_graph { clip_ctx * ctx; - const clip_vision_model & model; + const clip_model & model; const clip_hparams & hparams; // we only support single image per batch @@ -423,7 +443,7 @@ struct clip_graph { clip_graph(clip_ctx * ctx, const clip_image_f32 & img) : ctx(ctx), - model(ctx->vision_model), + model(ctx->model), hparams(model.hparams), img(img), patch_size(hparams.patch_size), @@ -455,7 +475,7 @@ struct clip_graph { model.position_embeddings, nullptr); - if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { + if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) { const int batch_size = 1; GGML_ASSERT(n_patches_x == n_patches_y); const int patches_per_image = n_patches_x; @@ -478,7 +498,7 @@ struct clip_graph { ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)), cur); - } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) { // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578 const int scale_factor = model.hparams.proj_scale_factor; @@ -612,7 +632,7 @@ struct clip_graph { const int n_pos = n_patches; const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position - norm_type norm_t = ctx->proj_type == PROJECTOR_TYPE_QWEN25VL + norm_type norm_t = ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL ? NORM_TYPE_RMS // qwen 2.5 vl : NORM_TYPE_NORMAL; // qwen 2 vl @@ -828,11 +848,11 @@ struct clip_graph { const int d_head = 128; int n_head = n_embd/d_head; int num_query = 96; - if (ctx->minicpmv_version == 2) { + if (ctx->model.hparams.minicpmv_version == 2) { num_query = 96; - } else if (ctx->minicpmv_version == 3) { + } else if (ctx->model.hparams.minicpmv_version == 3) { num_query = 64; - } else if (ctx->minicpmv_version == 4) { + } else if (ctx->model.hparams.minicpmv_version == 4) { num_query = 64; } @@ -1049,7 +1069,7 @@ struct clip_graph { int il_last = hparams.n_layer - 1; int deepest_feature_layer = -1; - if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { + if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV || ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) { il_last += 1; } @@ -1183,7 +1203,7 @@ struct clip_graph { } // llava projector (also used by granite) - if (ctx->has_llava_projector) { + if (ctx->model.hparams.has_llava_projector) { embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); @@ -1197,7 +1217,7 @@ struct clip_graph { // print_tensor_info(embeddings, "embeddings"); // llava projector - if (ctx->proj_type == PROJECTOR_TYPE_MLP) { + if (ctx->proj_type() == PROJECTOR_TYPE_MLP) { embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); @@ -1207,7 +1227,7 @@ struct clip_graph { embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); } } - else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { + else if (ctx->proj_type() == PROJECTOR_TYPE_MLP_NORM) { embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false); @@ -1228,7 +1248,7 @@ struct clip_graph { embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w), model.mm_4_b); } - else if (ctx->proj_type == PROJECTOR_TYPE_LDP) { + else if (ctx->proj_type() == PROJECTOR_TYPE_LDP) { // 
MobileVLM projector int n_patch = 24; ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings); @@ -1338,7 +1358,7 @@ struct clip_graph { } embeddings = block_1; } - else if (ctx->proj_type == PROJECTOR_TYPE_LDPV2) + else if (ctx->proj_type() == PROJECTOR_TYPE_LDPV2) { int n_patch = 24; ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); @@ -1368,7 +1388,7 @@ struct clip_graph { } // glm projector - else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { + else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) { size_t gridsz = (size_t)sqrt(embeddings->ne[1]); embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3)); embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); @@ -1408,6 +1428,114 @@ struct clip_graph { return gf; } + // whisper encoder with custom projector + ggml_cgraph * build_whisper_enc() { + const int n_frames = img.nx; + const int n_pos = n_frames / 2; + GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos); + + ggml_tensor * inp = build_inp_raw(1); + + // conv1d block + { + // convolution + gelu + ggml_tensor * cur = ggml_conv_1d_ph(ctx0, model.conv1d_1_w, inp, 1, 1); + cur = ggml_add(ctx0, cur, model.conv1d_1_b); + + cur = ggml_gelu_erf(ctx0, cur); + + cur = ggml_conv_1d_ph(ctx0, model.conv1d_2_w, cur, 2, 1); + cur = ggml_add(ctx0, cur, model.conv1d_2_b); + + cur = ggml_gelu_erf(ctx0, cur); + // transpose + inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + cb(inp, "after_conv1d", -1); + } + + // sanity check (only check one layer, but it should be the same for all) + GGML_ASSERT(model.layers[0].ln_1_w && model.layers[0].ln_1_b); + GGML_ASSERT(model.layers[0].ln_2_w && model.layers[0].ln_2_b); + GGML_ASSERT(model.layers[0].q_b); + GGML_ASSERT(model.layers[0].v_b); + GGML_ASSERT(!model.layers[0].k_b); // no bias for k + GGML_ASSERT(model.post_ln_w && model.post_ln_b); + + ggml_tensor * pos_embd_selected = ggml_view_2d( + ctx0, model.position_embeddings, + model.position_embeddings->ne[0], n_pos, + model.position_embeddings->nb[1], 0 + ); + ggml_tensor * cur = build_vit( + inp, n_pos, + NORM_TYPE_NORMAL, + hparams.ffn_op, + pos_embd_selected, + nullptr); + + cb(cur, "after_transformer", -1); + + if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) { + // StackAudioFrames + // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py + { + int64_t stride = n_embd * hparams.proj_stack_factor; + int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride); + int64_t pad = padded_len - ggml_nelements(cur); + if (pad > 0) { + cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0); + cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); + } + cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, + ggml_row_size(cur->type, stride), 0); + } + + cb(cur, "after_stacked", -1); + + // UltravoxProjector + { + // pre-norm + cur = ggml_rms_norm(ctx0, cur, 1e-6); + cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); + + // ffn in + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + + // swiglu + { + int64_t split_point = cur->ne[0] / 2; + ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); + ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); + + // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half + x1 = ggml_silu(ctx0, x1); + cur = ggml_mul(ctx0, x0, x1); + } + + // mid-norm + cur = ggml_rms_norm(ctx0, cur, 1e-6); + cur = 
ggml_mul(ctx0, cur, model.mm_norm_mid_w); + + // ffn out + cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); + } + + } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) { + // projector + cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur); + cur = ggml_add(ctx0, cur, model.mm_fc_b); + + } else { + GGML_ABORT("%s: unknown projector type", __func__); + } + + cb(cur, "projected", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + private: // // utility functions @@ -1541,6 +1669,17 @@ private: inpL = cur; } + // TODO @ngxson : find a way to move this outside + if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) { + ggml_tensor * cur = inpL; + cur = ggml_transpose(ctx0, cur); + cur = ggml_cont(ctx0, cur); + cur = ggml_pool_1d(ctx0, cur, GGML_OP_POOL_AVG, 2, 2, 0); + cur = ggml_transpose(ctx0, cur); + cur = ggml_cont(ctx0, cur); + inpL = cur; + } + // post-layernorm if (model.post_ln_w) { inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1); @@ -1562,8 +1701,8 @@ private: return inp; } - ggml_tensor * build_inp_raw() { - ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3); + ggml_tensor * build_inp_raw(int channels = 3) { + ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels); ggml_set_name(inp_raw, "inp_raw"); ggml_set_input(inp_raw); return inp_raw; @@ -1641,6 +1780,11 @@ private: cur = ggml_gelu(ctx0, cur); cb(cur, "ffn_gelu", il); } break; + case FFN_GELU_ERF: + { + cur = ggml_gelu_erf(ctx0, cur); + cb(cur, "ggml_gelu_erf", il); + } break; case FFN_GELU_QUICK: { cur = ggml_gelu_quick(ctx0, cur); @@ -1805,7 +1949,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 ggml_cgraph * res; - switch (ctx->proj_type) { + switch (ctx->proj_type()) { case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_IDEFICS3: { @@ -1832,6 +1976,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_llama4(); } break; + case PROJECTOR_TYPE_ULTRAVOX: + case PROJECTOR_TYPE_QWEN2A: + { + res = graph.build_whisper_enc(); + } break; default: { res = graph.build_llava(); @@ -1844,13 +1993,15 @@ struct clip_model_loader { ggml_context_ptr ctx_meta; gguf_context_ptr ctx_gguf; - clip_ctx & ctx_clip; std::string fname; size_t model_size = 0; // in bytes - // TODO @ngxson : we should not pass clip_ctx here, it should be clip_vision_model - clip_model_loader(const char * fname, clip_ctx & ctx_clip) : ctx_clip(ctx_clip), fname(fname) { + bool has_vision = false; + bool has_audio = false; + + // TODO @ngxson : we should not pass clip_ctx here, it should be clip_model + clip_model_loader(const char * fname) : fname(fname) { struct ggml_context * meta = nullptr; struct gguf_init_params params = { @@ -1882,6 +2033,19 @@ struct clip_model_loader { LOG_INF("\n"); } + // modalities + { + get_bool(KEY_HAS_VISION_ENC, has_vision, false); + get_bool(KEY_HAS_AUDIO_ENC, has_audio, false); + + if (has_vision) { + LOG_INF("%s: has vision encoder\n", __func__); + } + if (has_audio) { + LOG_INF("%s: has audio encoder\n", __func__); + } + } + // tensors { for (int i = 0; i < n_tensors; ++i) { @@ -1897,44 +2061,72 @@ struct clip_model_loader { } } - void load_hparams() { - auto & hparams = ctx_clip.vision_model.hparams; + void load_hparams(clip_model & model, clip_modality modality) { + auto & hparams = model.hparams; std::string log_ffn_op; // for logging + // sanity check + if (modality == CLIP_MODALITY_VISION) { + GGML_ASSERT(has_vision); + } else if (modality == 
CLIP_MODALITY_AUDIO) { + GGML_ASSERT(has_audio); + } + model.modality = modality; + + // projector type std::string proj_type; { get_string(KEY_PROJ_TYPE, proj_type, false); if (!proj_type.empty()) { - ctx_clip.proj_type = clip_projector_type_from_string(proj_type); + model.proj_type = clip_projector_type_from_string(proj_type); } - if (ctx_clip.proj_type == PROJECTOR_TYPE_UNKNOWN) { + if (model.proj_type == PROJECTOR_TYPE_UNKNOWN) { throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str())); } + + // correct arch for multimodal models + if (model.proj_type == PROJECTOR_TYPE_QWEN25O) { + model.proj_type = modality == CLIP_MODALITY_VISION + ? PROJECTOR_TYPE_QWEN25VL + : PROJECTOR_TYPE_QWEN2A; + } } + const bool is_vision = model.modality == CLIP_MODALITY_VISION; + const bool is_audio = model.modality == CLIP_MODALITY_AUDIO; + // other hparams { - get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy + const char * prefix = is_vision ? "vision" : "audio"; + get_u32(string_format(KEY_N_EMBD, prefix), hparams.n_embd); + get_u32(string_format(KEY_N_HEAD, prefix), hparams.n_head); + get_u32(string_format(KEY_N_FF, prefix), hparams.n_ff); + get_u32(string_format(KEY_N_BLOCK, prefix), hparams.n_layer); + get_u32(string_format(KEY_PROJ_DIM, prefix), hparams.projection_dim); + get_f32(string_format(KEY_LAYER_NORM_EPS, prefix), hparams.eps); - get_u32(KEY_N_EMBD, hparams.n_embd); - get_u32(KEY_N_HEAD, hparams.n_head); - get_u32(KEY_N_FF, hparams.n_ff); - get_u32(KEY_N_BLOCK, hparams.n_layer); - get_u32(KEY_PROJ_DIM, hparams.projection_dim); - get_f32(KEY_LAYER_NORM_EPS, hparams.eps); - get_u32(KEY_IMAGE_SIZE, hparams.image_size); - get_u32(KEY_PATCH_SIZE, hparams.patch_size); - get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); - get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false); + if (is_vision) { + get_u32(KEY_IMAGE_SIZE, hparams.image_size); + get_u32(KEY_PATCH_SIZE, hparams.patch_size); + get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); + get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false); + get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy + + } else if (is_audio) { + get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins); + + } else { + GGML_ASSERT(false && "unknown modality"); + } // default warmup value hparams.warmup_image_size = hparams.image_size; - ctx_clip.has_llava_projector = ctx_clip.proj_type == PROJECTOR_TYPE_MLP - || ctx_clip.proj_type == PROJECTOR_TYPE_MLP_NORM - || ctx_clip.proj_type == PROJECTOR_TYPE_LDP - || ctx_clip.proj_type == PROJECTOR_TYPE_LDPV2; + hparams.has_llava_projector = model.proj_type == PROJECTOR_TYPE_MLP + || model.proj_type == PROJECTOR_TYPE_MLP_NORM + || model.proj_type == PROJECTOR_TYPE_LDP + || model.proj_type == PROJECTOR_TYPE_LDPV2; { bool use_gelu = false; @@ -1964,7 +2156,7 @@ struct clip_model_loader { } } - { + if (is_vision) { int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN); int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD); GGML_ASSERT(idx_mean >= 0 && "image_mean not found"); @@ -1972,8 +2164,8 @@ struct clip_model_loader { const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean); const float * std_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std); for (int i = 0; i < 3; ++i) { - ctx_clip.image_mean[i] = mean_data[i]; - ctx_clip.image_std[i] = std_data[i]; + hparams.image_mean[i] = mean_data[i]; + 
hparams.image_std[i] = std_data[i]; } } @@ -1990,11 +2182,11 @@ struct clip_model_loader { } // model-specific params - switch (ctx_clip.proj_type) { + switch (model.proj_type) { case PROJECTOR_TYPE_MINICPMV: { - if (ctx_clip.minicpmv_version == 0) { - ctx_clip.minicpmv_version = 2; // default to 2 if not set + if (hparams.minicpmv_version == 0) { + hparams.minicpmv_version = 2; // default to 2 if not set } } break; case PROJECTOR_TYPE_IDEFICS3: @@ -2050,6 +2242,17 @@ struct clip_model_loader { isize, isize*3, // 336, 1008 }; } break; + case PROJECTOR_TYPE_ULTRAVOX: + case PROJECTOR_TYPE_QWEN2A: + { + bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX; + get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack); + if (hparams.n_mel_bins != 128) { + throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__)); + } + hparams.ffn_op = FFN_GELU_ERF; + log_ffn_op = "gelu_erf"; // temporary solution for logging + } break; default: break; } @@ -2059,29 +2262,36 @@ struct clip_model_loader { LOG_INF("%s: n_head: %d\n", __func__, hparams.n_head); LOG_INF("%s: n_ff: %d\n", __func__, hparams.n_ff); LOG_INF("%s: n_layer: %d\n", __func__, hparams.n_layer); - LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim); - LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size); - LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size); - LOG_INF("\n"); - LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector); - LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version); - LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor); - LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); LOG_INF("%s: ffn_op: %s\n", __func__, log_ffn_op.c_str()); + LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim); + if (is_vision) { + LOG_INF("\n--- vision hparams ---\n"); + LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size); + LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size); + LOG_INF("%s: has_llava_proj: %d\n", __func__, hparams.has_llava_projector); + LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version); + LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor); + LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); + } else if (is_audio) { + LOG_INF("\n--- audio hparams ---\n"); + LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins); + LOG_INF("%s: proj_stack_factor: %d\n", __func__, hparams.proj_stack_factor); + } + LOG_INF("\n"); LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0); - - if (ctx_clip.proj_type == PROJECTOR_TYPE_LLAMA4) { - LOG_WRN("%s: llama 4 vision is known to have degraded quality: https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__); - } } } - void load_tensors() { - auto & hparams = ctx_clip.vision_model.hparams; + void load_tensors(clip_ctx & ctx_clip) { + auto & model = ctx_clip.model; + auto & hparams = model.hparams; std::map tensor_offset; std::vector tensors_to_load; + // TODO @ngxson : support both audio and video in the future + const char * prefix = model.modality == CLIP_MODALITY_AUDIO ? 
"a" : "v"; + // get offsets for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) { const char * name = gguf_get_tensor_name(ctx_gguf.get(), i); @@ -2115,51 +2325,49 @@ struct clip_model_loader { return cur; }; - auto & vision_model = ctx_clip.vision_model; + model.class_embedding = get_tensor(TN_CLASS_EMBD, false); - vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false); + model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false); + model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"), false); - vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, "v", "weight"), false); - vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, "v", "bias"), false); + model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false); + model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"), false); - vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, "v", "weight"), false); - vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, "v", "bias"), false); + model.patch_bias = get_tensor(TN_PATCH_BIAS, false); + model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false); + model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false); - vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false); - vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false); - vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false); - - vision_model.position_embeddings = get_tensor(TN_POS_EMBD, false); + model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false); // layers - vision_model.layers.resize(hparams.n_layer); + model.layers.resize(hparams.n_layer); for (int il = 0; il < hparams.n_layer; ++il) { - auto & layer = vision_model.layers[il]; - layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight")); - layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight")); - layer.v_w = get_tensor(string_format(TN_ATTN_V, "v", il, "weight")); - layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight")); - layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, "v", il, "weight"), false); - layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, "v", il, "weight"), false); - layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false); - layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false); - layer.ls_1_w = get_tensor(string_format(TN_LS_1, "v", il, "weight"), false); // no bias - layer.ls_2_w = get_tensor(string_format(TN_LS_2, "v", il, "weight"), false); // no bias + auto & layer = model.layers[il]; + layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight")); + layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight")); + layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight")); + layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight")); + layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false); + layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false); + layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false); + layer.ln_2_w = get_tensor(string_format(TN_LN_2, prefix, il, "weight"), false); + layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias + layer.ls_2_w = get_tensor(string_format(TN_LS_2, prefix, il, "weight"), false); // no bias - layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false); - layer.q_b = 
get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false); - layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false); - layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false); - layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false); - layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false); + layer.k_b = get_tensor(string_format(TN_ATTN_K, prefix, il, "bias"), false); + layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false); + layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false); + layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false); + layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false); + layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false); // ffn - layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight")); - layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false); - layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false); - layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), false); - layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight")); - layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false); + layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, prefix, il, "weight")); + layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, prefix, il, "bias"), false); + layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, prefix, il, "weight"), false); + layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, prefix, il, "bias"), false); + layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight")); + layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false); // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check! 
@@ -2175,146 +2383,166 @@ struct clip_model_loader { } } - switch (ctx_clip.proj_type) { + switch (model.proj_type) { case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_MLP_NORM: { // LLaVA projection - vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false); - vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); // Yi-type llava - vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false); - vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); // missing in Yi-type llava - vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false); - vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); // Yi-type llava - vision_model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false); - vision_model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false); - vision_model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false); - vision_model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false); - if (vision_model.mm_3_w) { + model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false); + model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false); + model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false); + model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false); + if (model.mm_3_w) { // TODO: this is a hack to support Yi-type llava - ctx_clip.proj_type = PROJECTOR_TYPE_MLP_NORM; + model.proj_type = PROJECTOR_TYPE_MLP_NORM; } - vision_model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false); + model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false); } break; case PROJECTOR_TYPE_LDP: { // MobileVLM projection - vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); - vision_model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias")); - vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); - vision_model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); - vision_model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight")); - vision_model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight")); - vision_model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias")); - vision_model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight")); - vision_model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias")); - vision_model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight")); - vision_model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias")); - vision_model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight")); - 
vision_model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight")); - vision_model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias")); - vision_model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight")); - vision_model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight")); - vision_model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias")); - vision_model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight")); - vision_model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias")); - vision_model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight")); - vision_model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias")); - vision_model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight")); - vision_model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight")); - vision_model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias")); + model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias")); + model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); + model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); + model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight")); + model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight")); + model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias")); + model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight")); + model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias")); + model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight")); + model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias")); + model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight")); + model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight")); + model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias")); + model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight")); + model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight")); + model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias")); + model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight")); + model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias")); + model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight")); + model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias")); + model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, 
"0.weight")); + model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight")); + model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias")); } break; case PROJECTOR_TYPE_LDPV2: { // MobilVLM_V2 projection - vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); - vision_model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); - vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); - vision_model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias")); - vision_model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight")); - vision_model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias")); + model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); + model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); + model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); + model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias")); + model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight")); + model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias")); } break; case PROJECTOR_TYPE_MINICPMV: { - // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD); - vision_model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K); - vision_model.mm_model_query = get_tensor(TN_MINICPMV_QUERY); - vision_model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ); - vision_model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ); - vision_model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight")); - vision_model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight")); - vision_model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight")); - vision_model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias")); - vision_model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias")); - vision_model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias")); - vision_model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight")); - vision_model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias")); - vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight")); - vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias")); - vision_model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight")); - vision_model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias")); - vision_model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight")); - vision_model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias")); + // model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD); + model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K); + model.mm_model_query = get_tensor(TN_MINICPMV_QUERY); + model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ); + model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ); + model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight")); + model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight")); + 
model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight")); + model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias")); + model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias")); + model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias")); + model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight")); + model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias")); + model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight")); + model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias")); + model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight")); + model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias")); + model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight")); + model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias")); } break; case PROJECTOR_TYPE_GLM_EDGE: { - vision_model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight")); - vision_model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias")); - vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR, "weight")); - vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "weight")); - vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "bias")); - vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight")); - vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight")); - vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight")); - vision_model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight")); - vision_model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight")); + model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight")); + model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias")); + model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR, "weight")); + model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "weight")); + model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "bias")); + model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight")); + model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight")); + model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight")); + model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight")); + model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight")); } break; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: { - vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); - vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); - vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); - vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; case 
PROJECTOR_TYPE_GEMMA3: { - vision_model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); - vision_model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); + model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); + model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); } break; case PROJECTOR_TYPE_IDEFICS3: { - vision_model.projection = get_tensor(TN_MM_PROJECTOR); + model.projection = get_tensor(TN_MM_PROJECTOR); } break; case PROJECTOR_TYPE_PIXTRAL: { - vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); - vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); - vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); - vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); // [IMG_BREAK] token embedding - vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK); + model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK); // for mistral small 3.1 - vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); - vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); + model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); + } break; + case PROJECTOR_TYPE_ULTRAVOX: + { + model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); + model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias")); + model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight")); + model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias")); + model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight")); + model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight")); + model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight")); + model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight")); + } break; + case PROJECTOR_TYPE_QWEN2A: + { + model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); + model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias")); + model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight")); + model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias")); + model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight")); + model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias")); } break; case PROJECTOR_TYPE_INTERNVL: { - vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); - vision_model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); - vision_model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); - vision_model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias")); - vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); - vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); + model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); + model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias")); + model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); + model.mm_3_b = 
get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); } break; case PROJECTOR_TYPE_LLAMA4: { - vision_model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); - vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); - vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); + model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); + model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); } break; default: GGML_ASSERT(false && "unknown projector type"); @@ -2357,15 +2585,20 @@ struct clip_model_loader { } } - void alloc_compute_meta() { + void alloc_compute_meta(clip_ctx & ctx_clip) { + const auto & hparams = ctx_clip.model.hparams; ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); // create a fake batch clip_image_f32_batch batch; clip_image_f32_ptr img(clip_image_f32_init()); - img->nx = ctx_clip.vision_model.hparams.warmup_image_size; - img->ny = ctx_clip.vision_model.hparams.warmup_image_size; - img->buf.resize(img->nx * img->ny * 3); + if (ctx_clip.model.modality == CLIP_MODALITY_VISION) { + img->nx = hparams.warmup_image_size; + img->ny = hparams.warmup_image_size; + } else { + img->nx = hparams.warmup_audio_size; + img->ny = hparams.n_mel_bins; + } batch.entries.push_back(std::move(img)); ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch); @@ -2443,23 +2676,40 @@ struct clip_model_loader { } }; -struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params) { +struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) { g_logger_state.verbosity_thold = ctx_params.verbosity; - clip_ctx * ctx_clip = nullptr; + clip_ctx * ctx_vision = nullptr; + clip_ctx * ctx_audio = nullptr; try { - ctx_clip = new clip_ctx(ctx_params); - clip_model_loader loader(fname, *ctx_clip); - loader.load_hparams(); - loader.load_tensors(); - loader.alloc_compute_meta(); + clip_model_loader loader(fname); + + if (loader.has_vision) { + ctx_vision = new clip_ctx(ctx_params); + loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION); + loader.load_tensors(*ctx_vision); + loader.alloc_compute_meta(*ctx_vision); + } + + if (loader.has_audio) { + ctx_audio = new clip_ctx(ctx_params); + loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO); + loader.load_tensors(*ctx_audio); + loader.alloc_compute_meta(*ctx_audio); + } + } catch (const std::exception & e) { LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what()); - delete ctx_clip; - return nullptr; + if (ctx_vision) { + delete ctx_vision; + } + if (ctx_audio) { + delete ctx_audio; + } + return {nullptr, nullptr}; } - return ctx_clip; + return {ctx_vision, ctx_audio}; } struct clip_image_size * clip_image_size_init() { @@ -2533,30 +2783,6 @@ void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny memcpy(img->buf.data(), rgb_pixels, img->buf.size()); } -bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { - int nx, ny, nc; - auto * data = stbi_load(fname, &nx, &ny, &nc, 3); - if (!data) { - LOG_ERR("%s: failed to load image '%s'\n", __func__, fname); - return false; - } - clip_build_img_from_pixels(data, nx, ny, img); - stbi_image_free(data); - return true; -} - -bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img) { - int nx, ny, nc; - auto * data = 
stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3); - if (!data) { - LOG_ERR("%s: failed to decode image bytes\n", __func__); - return false; - } - clip_build_img_from_pixels(data, nx, ny, img); - stbi_image_free(data); - return true; -} - // Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) { dst.nx = src.nx; @@ -2820,12 +3046,12 @@ struct llava_uhd { const float ratio = (float)original_width * original_height / (slice_size * slice_size); const int multiple = fmin(ceil(ratio), max_slice_nums); const bool has_slices = (multiple > 1); - const bool has_pinpoints = !ctx->vision_model.hparams.image_grid_pinpoints.empty(); + const bool has_pinpoints = !ctx->model.hparams.image_grid_pinpoints.empty(); if (has_pinpoints) { // has pinpoints, use them to calculate the grid size (e.g. llava-1.6) auto refine_size = llava_uhd::select_best_resolution( - ctx->vision_model.hparams.image_grid_pinpoints, + ctx->model.hparams.image_grid_pinpoints, original_size); res.overview_size = clip_image_size{slice_size, slice_size}; res.refined_size = refine_size; @@ -3047,7 +3273,7 @@ private: bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { clip_image_size original_size{img->nx, img->ny}; bool pad_to_square = true; - auto & params = ctx->vision_model.hparams; + auto & params = ctx->model.hparams; // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) { pad_to_square = false; @@ -3060,7 +3286,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str for (size_t i = 0; i < imgs.size(); ++i) { // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp"); clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); } @@ -3068,7 +3294,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->grid_y = inst.grid_size.height; return true; - } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) { clip_image_u8 resized; auto patch_size = params.patch_size * 2; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size); @@ -3076,42 +3302,42 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str clip_image_f32_ptr img_f32(clip_image_f32_init()); // clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(resized, *img_f32, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std); // res_imgs->data[0] = *res; res_imgs->entries.push_back(std::move(img_f32)); return true; } - else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE - || ctx->proj_type == PROJECTOR_TYPE_GEMMA3 - || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 - || ctx->proj_type == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution + 
else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE + || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3 + || ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3 + || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution ) { clip_image_u8 resized_image; int sz = params.image_size; image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz}); clip_image_f32_ptr img_f32(clip_image_f32_init()); //clip_image_save_to_bmp(resized_image, "resized.bmp"); - normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(img_f32)); return true; - } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) { clip_image_u8 resized_image; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size); image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height); clip_image_f32_ptr img_f32(clip_image_f32_init()); - normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(img_f32)); return true; - } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_LLAMA4) { GGML_ASSERT(!params.image_grid_pinpoints.empty()); auto const inst = llava_uhd::get_slice_instructions(ctx, original_size); std::vector imgs = llava_uhd::slice_image(img, inst); for (size_t i = 0; i < imgs.size(); ++i) { clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); } @@ -3141,7 +3367,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str image_manipulation::resize_and_pad_image(*img, *temp, clip_image_size{params.image_size, params.image_size}, pad_color); clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(*temp, *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); return true; @@ -3153,7 +3379,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str for (size_t i = 0; i < imgs.size(); ++i) { // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp"); clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); } @@ -3165,7 +3391,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) { - return ctx->vision_model.image_newline; + return ctx->model.image_newline; } void clip_free(clip_ctx * ctx) { @@ -3177,8 +3403,8 @@ void clip_free(clip_ctx * ctx) { // deprecated size_t clip_embd_nbytes(const struct clip_ctx * ctx) { - const int32_t nx = ctx->vision_model.hparams.image_size; - const int32_t ny = ctx->vision_model.hparams.image_size; + const int32_t nx = ctx->model.hparams.image_size; + const int32_t ny = ctx->model.hparams.image_size; return 
clip_embd_nbytes_by_img(ctx, nx, ny); } @@ -3190,97 +3416,135 @@ size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h } int32_t clip_get_image_size(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.image_size; + return ctx->model.hparams.image_size; } int32_t clip_get_patch_size(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.patch_size; + return ctx->model.hparams.patch_size; } int32_t clip_get_hidden_size(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.n_embd; + return ctx->model.hparams.n_embd; } const char * clip_patch_merge_type(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat"; + return ctx->model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat"; } const int32_t * clip_image_grid(const struct clip_ctx * ctx) { - if (ctx->vision_model.hparams.image_grid_pinpoints.size()) { - return &ctx->vision_model.hparams.image_grid_pinpoints.front(); + if (ctx->model.hparams.image_grid_pinpoints.size()) { + return &ctx->model.hparams.image_grid_pinpoints.front(); } return nullptr; } size_t get_clip_image_grid_size(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.image_grid_pinpoints.size(); + return ctx->model.hparams.image_grid_pinpoints.size(); } int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) { - const auto & params = ctx->vision_model.hparams; + const auto & params = ctx->model.hparams; const int n_total = clip_n_output_tokens(ctx, img); - if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { + if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) { return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0); } return n_total; } int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) { - const auto & params = ctx->vision_model.hparams; - if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { + const auto & params = ctx->model.hparams; + if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) { return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0); } return 1; } int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) { - const auto & params = ctx->vision_model.hparams; + const auto & params = ctx->model.hparams; - int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); - int scale_factor = ctx->vision_model.hparams.proj_scale_factor; + // only for models using fixed size square images + int n_patches_sq = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); - if (ctx->proj_type == PROJECTOR_TYPE_LDP - || ctx->proj_type == PROJECTOR_TYPE_LDPV2 - || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { - n_patches /= 4; - if (ctx->vision_model.mm_glm_tok_boi) { - n_patches += 2; // for BOI and EOI token embeddings - } - } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { - if (ctx->minicpmv_version == 2) { - n_patches = 96; - } - else if (ctx->minicpmv_version == 3) { - n_patches = 64; - } - else if (ctx->minicpmv_version == 4) { - n_patches = 64; - } - else { - GGML_ABORT("Unknown minicpmv version"); - } - } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { - int 
patch_size = params.patch_size * 2; - int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); - int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); - n_patches = x_patch * y_patch; - } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { - int n_per_side = params.image_size / params.patch_size; - int n_per_side_2d_pool = n_per_side / params.proj_scale_factor; - n_patches = n_per_side_2d_pool * n_per_side_2d_pool; - } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_INTERNVL) { - // both W and H are divided by proj_scale_factor - n_patches /= (params.proj_scale_factor * params.proj_scale_factor); - } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { - int n_merge = params.spatial_merge_size; - int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1); - int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1); - n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row - } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) { - n_patches /= (scale_factor * scale_factor); + projector_type proj = ctx->proj_type(); + + switch (proj) { + case PROJECTOR_TYPE_MLP: + case PROJECTOR_TYPE_MLP_NORM: + { + // do nothing + } break; + case PROJECTOR_TYPE_LDP: + case PROJECTOR_TYPE_LDPV2: + case PROJECTOR_TYPE_GLM_EDGE: + { + n_patches_sq /= 4; + if (ctx->model.mm_glm_tok_boi) { + n_patches_sq += 2; // for BOI and EOI token embeddings + } + } break; + case PROJECTOR_TYPE_MINICPMV: + { + if (params.minicpmv_version == 2) { + n_patches_sq = 96; + } else if (params.minicpmv_version == 3) { + n_patches_sq = 64; + } else if (params.minicpmv_version == 4) { + n_patches_sq = 64; + } else { + GGML_ABORT("Unknown minicpmv version"); + } + } break; + case PROJECTOR_TYPE_QWEN2VL: + case PROJECTOR_TYPE_QWEN25VL: + { + // dynamic size + int patch_size = params.patch_size * 2; + int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); + int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); + n_patches_sq = x_patch * y_patch; + } break; + case PROJECTOR_TYPE_GEMMA3: + { + int n_per_side = params.image_size / params.patch_size; + int n_per_side_2d_pool = n_per_side / params.proj_scale_factor; + n_patches_sq = n_per_side_2d_pool * n_per_side_2d_pool; + } break; + case PROJECTOR_TYPE_IDEFICS3: + case PROJECTOR_TYPE_INTERNVL: + { + // both W and H are divided by proj_scale_factor + n_patches_sq /= (params.proj_scale_factor * params.proj_scale_factor); + } break; + case PROJECTOR_TYPE_PIXTRAL: + { + // dynamic size + int n_merge = params.spatial_merge_size; + int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1); + int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? 
n_merge : 1); + n_patches_sq = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row + } break; + case PROJECTOR_TYPE_LLAMA4: + { + int scale_factor = ctx->model.hparams.proj_scale_factor; + n_patches_sq /= (scale_factor * scale_factor); + } break; + case PROJECTOR_TYPE_ULTRAVOX: + { + const int proj_stack_factor = ctx->model.hparams.proj_stack_factor; + const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor); + n_patches_sq = n_len / proj_stack_factor / 2; + } break; + case PROJECTOR_TYPE_QWEN2A: + { + // divide by 2 because of whisper + // another divide by 2 because of nn.AvgPool1d(2, stride=2) + n_patches_sq = img->nx / 4; + } break; + default: + GGML_ABORT("unsupported projector type"); } - return n_patches; + return n_patches_sq; } static std::vector>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector> & pos) { @@ -3395,7 +3659,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_sched_alloc_graph(ctx->sched.get(), gf); // set inputs - const auto & model = ctx->vision_model; + const auto & model = ctx->model; const auto & hparams = model.hparams; const int image_size_width = imgs.entries[0]->nx; @@ -3435,7 +3699,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima }; // set input pixel values - { + if (!imgs.is_audio) { size_t nelem = 0; for (const auto & img : imgs.entries) { nelem += img->nx * img->ny * 3; @@ -3472,10 +3736,20 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } set_input_f32("inp_raw", inp_raw); + + } else { + // audio input + GGML_ASSERT(imgs.entries.size() == 1); + const auto & mel_inp = imgs.entries[0]; + const int n_step = mel_inp->nx; + const int n_mel = mel_inp->ny; + std::vector inp_raw(n_step * n_mel); + std::memcpy(inp_raw.data(), mel_inp->buf.data(), n_step * n_mel * sizeof(float)); + set_input_f32("inp_raw", inp_raw); } // set input per projector - switch (ctx->proj_type) { + switch (ctx->model.proj_type) { case PROJECTOR_TYPE_MINICPMV: { // inspired from siglip: @@ -3668,6 +3942,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: + case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_ULTRAVOX: { // do nothing } break; @@ -3727,7 +4003,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int n_tokens_out = embeddings->ne[1]; const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get()); if (n_tokens_out != expected_n_tokens_out) { - LOG_ERR("%s: expected %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out); + LOG_ERR("%s: expected output %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out); GGML_ABORT("Invalid number of output tokens"); } @@ -3738,64 +4014,83 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } int clip_n_mmproj_embd(const struct clip_ctx * ctx) { - switch (ctx->proj_type) { + const auto & hparams = ctx->model.hparams; + switch (ctx->model.proj_type) { case PROJECTOR_TYPE_LDP: - return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0]; + return ctx->model.mm_model_block_1_block_2_1_b->ne[0]; case PROJECTOR_TYPE_LDPV2: - return ctx->vision_model.mm_model_peg_0_b->ne[0]; + return ctx->model.mm_model_peg_0_b->ne[0]; case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_PIXTRAL: - return ctx->vision_model.mm_2_w->ne[1]; + return 
ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_MLP_NORM: - return ctx->vision_model.mm_3_b->ne[0]; + return ctx->model.mm_3_b->ne[0]; case PROJECTOR_TYPE_MINICPMV: - if (ctx->minicpmv_version == 2) { + if (hparams.minicpmv_version == 2) { return 4096; - } else if (ctx->minicpmv_version == 3) { + } else if (hparams.minicpmv_version == 3) { return 3584; - } else if (ctx->minicpmv_version == 4) { + } else if (hparams.minicpmv_version == 4) { return 3584; } GGML_ABORT("Unknown minicpmv version"); case PROJECTOR_TYPE_GLM_EDGE: - return ctx->vision_model.mm_model_mlp_3_w->ne[1]; + return ctx->model.mm_model_mlp_3_w->ne[1]; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: - return ctx->vision_model.mm_1_b->ne[0]; + return ctx->model.mm_1_b->ne[0]; case PROJECTOR_TYPE_GEMMA3: - return ctx->vision_model.mm_input_proj_w->ne[0]; + return ctx->model.mm_input_proj_w->ne[0]; case PROJECTOR_TYPE_IDEFICS3: - return ctx->vision_model.projection->ne[1]; + return ctx->model.projection->ne[1]; + case PROJECTOR_TYPE_ULTRAVOX: + return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_INTERNVL: - return ctx->vision_model.mm_3_w->ne[1]; + return ctx->model.mm_3_w->ne[1]; case PROJECTOR_TYPE_LLAMA4: - return ctx->vision_model.mm_model_proj->ne[1]; + return ctx->model.mm_model_proj->ne[1]; + case PROJECTOR_TYPE_QWEN2A: + return ctx->model.mm_fc_w->ne[1]; default: GGML_ABORT("Unknown projector type"); } } int clip_is_minicpmv(const struct clip_ctx * ctx) { - if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { - return ctx->minicpmv_version; + if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV) { + return ctx->model.hparams.minicpmv_version; } return 0; } bool clip_is_glm(const struct clip_ctx * ctx) { - return ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE; + return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE; } bool clip_is_qwen2vl(const struct clip_ctx * ctx) { - return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL; + return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL + || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL; } bool clip_is_llava(const struct clip_ctx * ctx) { - return ctx->has_llava_projector; + return ctx->model.hparams.has_llava_projector; } bool clip_is_gemma3(const struct clip_ctx * ctx) { - return ctx->proj_type == PROJECTOR_TYPE_GEMMA3; + return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3; +} + +bool clip_has_vision_encoder(const struct clip_ctx * ctx) { + return ctx->model.modality == CLIP_MODALITY_VISION; +} + +bool clip_has_audio_encoder(const struct clip_ctx * ctx) { + return ctx->model.modality == CLIP_MODALITY_AUDIO; +} + +bool clip_has_whisper_encoder(const struct clip_ctx * ctx) { + return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX + || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A; } bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { @@ -3816,5 +4111,16 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, // projector_type clip_get_projector_type(const struct clip_ctx * ctx) { - return ctx->proj_type; + return ctx->proj_type(); +} + +void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel) { + clip_image_f32 * audio = new clip_image_f32; + audio->nx = n_frames; + audio->ny = n_mel; + audio->buf.resize(n_frames * n_mel); + std::memcpy(audio->buf.data(), mel, n_frames * n_mel * sizeof(float)); + + batch->entries.push_back(clip_image_f32_ptr(audio)); + batch->is_audio = true; } diff --git a/tools/mtmd/clip.h 
b/tools/mtmd/clip.h index e7a1c0782..cb2eb261f 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -4,6 +4,8 @@ #include #include +// !!! Internal header, to be used by mtmd only !!! + struct clip_ctx; struct clip_image_size { @@ -15,12 +17,22 @@ struct clip_image_f32; struct clip_image_u8_batch; struct clip_image_f32_batch; +enum clip_modality { + CLIP_MODALITY_VISION, + CLIP_MODALITY_AUDIO, +}; + struct clip_context_params { bool use_gpu; enum ggml_log_level verbosity; }; -struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params); +struct clip_init_result { + struct clip_ctx * ctx_v; // vision context + struct clip_ctx * ctx_a; // audio context +}; + +struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params); void clip_free(struct clip_ctx * ctx); @@ -93,3 +105,10 @@ bool clip_is_llava(const struct clip_ctx * ctx); bool clip_is_gemma3(const struct clip_ctx * ctx); bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); + +// use by audio input +void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel); + +bool clip_has_vision_encoder(const struct clip_ctx * ctx); +bool clip_has_audio_encoder(const struct clip_ctx * ctx); +bool clip_has_whisper_encoder(const struct clip_ctx * ctx); diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp new file mode 100644 index 000000000..4d053895c --- /dev/null +++ b/tools/mtmd/mtmd-audio.cpp @@ -0,0 +1,769 @@ +#include "mtmd-audio.h" + +#define _USE_MATH_DEFINES // for M_PI +#include +#include +#include +#include +#include +#include +#include + +// most of the code here is copied from whisper.cpp + +// align x to upper multiple of n +#define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) + +namespace whisper_preprocessor { + +#define SIN_COS_N_COUNT WHISPER_N_FFT +namespace { +struct whisper_global_cache { + // In FFT, we frequently use sine and cosine operations with the same values. + // We can use precalculated values to speed up the process. 
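For orientation, a caller-side sketch of the split-context API declared in clip.h above: clip_init() now returns a clip_init_result holding an optional vision context and an optional audio context, each loaded, used and freed independently. The file name and control flow below are hypothetical; mtmd is the intended consumer.

#include "clip.h"

// Hypothetical caller: load one mmproj GGUF and keep whichever encoders it provides.
static bool load_encoders_sketch() {
    clip_context_params cparams;
    cparams.use_gpu   = true;
    cparams.verbosity = GGML_LOG_LEVEL_INFO;

    clip_init_result res = clip_init("mmproj-model.gguf", cparams);
    if (!res.ctx_v && !res.ctx_a) {
        return false; // clip_init returns {nullptr, nullptr} on failure
    }
    if (res.ctx_v) {
        // vision encoder present: preprocess images and call clip_image_batch_encode()
    }
    if (res.ctx_a) {
        // audio encoder present: fill batches with clip_image_f32_batch_add_mel()
    }
    if (res.ctx_v) { clip_free(res.ctx_v); }
    if (res.ctx_a) { clip_free(res.ctx_a); }
    return true;
}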
+ float sin_vals[SIN_COS_N_COUNT]; + float cos_vals[SIN_COS_N_COUNT]; + + // Hann window (Use cosf to eliminate difference) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 + float hann_window[WHISPER_N_FFT]; + + whisper_global_cache() { + fill_sin_cos_table(); + fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window); + } + + void fill_sin_cos_table() { + for (int i = 0; i < SIN_COS_N_COUNT; i++) { + double theta = (2 * M_PI * i) / SIN_COS_N_COUNT; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } + } + + void fill_hann_window(int length, bool periodic, float * output) { + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } + } +} global_cache; +} + +// naive Discrete Fourier Transform +// input is real-valued +// output is complex-valued +static void dft(const float* in, int N, float* out) { + const int sin_cos_step = SIN_COS_N_COUNT / N; + + for (int k = 0; k < N; k++) { + float re = 0; + float im = 0; + + for (int n = 0; n < N; n++) { + int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N + re += in[n]*global_cache.cos_vals[idx]; // cos(t) + im -= in[n]*global_cache.sin_vals[idx]; // sin(t) + } + + out[k*2 + 0] = re; + out[k*2 + 1] = im; + } +} + +// Cooley-Tukey FFT +// poor man's implementation - use something better +// input is real-valued +// output is complex-valued +static void fft(float* in, int N, float* out) { + if (N == 1) { + out[0] = in[0]; + out[1] = 0; + return; + } + + const int half_N = N / 2; + if (N - half_N*2 == 1) { + dft(in, N, out); + return; + } + + float* even = in + N; + for (int i = 0; i < half_N; ++i) { + even[i]= in[2*i]; + } + float* even_fft = out + 2 * N; + fft(even, half_N, even_fft); + + float* odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i] = in[2*i + 1]; + } + float* odd_fft = even_fft + N; + fft(odd, half_N, odd_fft); + + const int sin_cos_step = SIN_COS_N_COUNT / N; + for (int k = 0; k < half_N; k++) { + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = global_cache.cos_vals[idx]; // cos(t) + float im = -global_cache.sin_vals[idx]; // sin(t) + + float re_odd = odd_fft[2*k + 0]; + float im_odd = odd_fft[2*k + 1]; + + out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; + out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + + out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; + out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + } +} + +static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, + int n_samples, int frame_size, int frame_step, int n_threads, + const whisper_filters & filters, whisper_mel & mel) { + std::vector fft_in(frame_size * 2, 0.0); + std::vector fft_out(frame_size * 2 * 2 * 2); + + int n_fft = filters.n_fft; + int i = ith; + + // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist + WHISPER_ASSERT(n_fft == 1 + (frame_size / 2)); + + // calculate FFT only when fft_in are not all zero + for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { + const int offset = i * frame_step; + + // apply Hann window (~10% faster) + for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) { + fft_in[j] = hann[j] * samples[offset + j]; + } + + // fill the rest with zeros + if (n_samples - offset < frame_size) { + 
std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0); + } + + // FFT + fft(fft_in.data(), frame_size, fft_out.data()); + + // Calculate modulus^2 of complex numbers + // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. + for (int j = 0; j < n_fft; j++) { + fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); + } + + // mel spectrogram + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + // unroll loop (suggested by GH user @lunixbochs) + int k = 0; + for (k = 0; k < n_fft - 3; k += 4) { + sum += + fft_out[k + 0] * filters.data[j * n_fft + k + 0] + + fft_out[k + 1] * filters.data[j * n_fft + k + 1] + + fft_out[k + 2] * filters.data[j * n_fft + k + 2] + + fft_out[k + 3] * filters.data[j * n_fft + k + 3]; + } + // handle n_fft remainder + for (; k < n_fft; k++) { + sum += fft_out[k] * filters.data[j * n_fft + k]; + } + sum = log10(std::max(sum, 1e-10)); + mel.data[j * mel.n_len + i] = sum; + } + } + + // Otherwise fft_out are all zero + double sum = log10(1e-10); + for (; i < mel.n_len; i += n_threads) { + for (int j = 0; j < mel.n_mel; j++) { + mel.data[j * mel.n_len + i] = sum; + } + } +} + +// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 +static bool log_mel_spectrogram( + const float * samples, + const int n_samples, + const int /*sample_rate*/, + const int frame_size, + const int frame_step, + const int n_mel, + const int n_threads, + const whisper_filters & filters, + const bool debug, + whisper_mel & mel) { + //const int64_t t_start_us = ggml_time_us(); + + // Hann window + WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size"); + const float * hann = global_cache.hann_window; + + // Calculate the length of padding + int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; + int64_t stage_2_pad = frame_size / 2; + + // Initialize a vector and copy data from C array to it. 
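To make the framing in log_mel_spectrogram below concrete, here is a small worked example assuming the usual whisper.cpp constants (16 kHz sample rate, 400-sample FFT frames, 160-sample hop); those constants come from mtmd-audio.h / whisper.cpp and are an assumption here, not part of this hunk.

// Illustrative arithmetic only: how 10 s of 16 kHz audio becomes mel frames.
// Assumed: WHISPER_SAMPLE_RATE = 16000, WHISPER_N_FFT = 400, WHISPER_HOP_LENGTH = 160.
const int n_samples   = 10 * 16000;           // 160,000 input samples
const int stage_1_pad = 16000 * 30;           // 480,000 zeros appended (30 s, as noted below)
const int stage_2_pad = 400 / 2;              // 200 extra samples on each side (reflective at the start)
const int n_padded    = n_samples + stage_1_pad + 2 * stage_2_pad;   // 640,400
const int n_len       = (n_padded - 400) / 160;                      // 4,000 frames
// preprocess_audio() later slices the result into 3000-frame (30 s) chunks for clip.cpp
// and drops the trailing, all-padding remainder (here the final 1,000 frames).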
+ std::vector samples_padded; + samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); + std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); + + // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio + std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); + + // reflective pad 200 samples at the beginning of audio + std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); + + mel.n_mel = n_mel; + // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 + // Calculate number of frames + remove the last frame + mel.n_len = (samples_padded.size() - frame_size) / frame_step; + // Calculate semi-padded sample length to ensure compatibility + mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step; + mel.data.resize(mel.n_mel * mel.n_len); + + { + std::vector workers(n_threads - 1); + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw] = std::thread( + log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), + n_samples + stage_2_pad, frame_size, frame_step, n_threads, + std::cref(filters), std::ref(mel)); + } + + // main thread + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel); + + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw].join(); + } + } + + // clamping and normalization + double mmax = -1e20; + for (int i = 0; i < mel.n_mel*mel.n_len; i++) { + if (mel.data[i] > mmax) { + mmax = mel.data[i]; + } + } + + mmax -= 8.0; + + for (int i = 0; i < mel.n_mel*mel.n_len; i++) { + if (mel.data[i] < mmax) { + mel.data[i] = mmax; + } + + mel.data[i] = (mel.data[i] + 4.0)/4.0; + } + + // Dump log_mel_spectrogram + if (debug) { + std::ofstream outFile("log_mel_spectrogram.json"); + outFile << "["; + for (uint64_t i = 0; i < mel.data.size() - 1; i++) { + outFile << mel.data[i] << ", "; + } + outFile << mel.data[mel.data.size() - 1] << "]"; + outFile.close(); + } + + return true; +} + +bool preprocess_audio( + const float * samples, + size_t n_samples, + const whisper_filters & filters, + std::vector & output) { + + if (n_samples == 0) { + // empty audio + return false; + } + + whisper_mel out_full; + bool ok = log_mel_spectrogram( + samples, + n_samples, + COMMON_SAMPLE_RATE, + WHISPER_N_FFT, + WHISPER_HOP_LENGTH, + filters.n_mel, + 4, // n_threads + filters, + false, // debug + out_full); + if (!ok) { + return false; + } + + // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel + // we always expect the mel to have 3000 silent frames at the end + // printf("n_len %d\n", out_full.n_len); + const size_t frames_per_chunk = 3000; + GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk); + for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) { + int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off); + if ((size_t)n_len < frames_per_chunk) { + break; // last uncomplete chunk will always be a padded chunk, safe to ignore + } + + whisper_mel out_chunk; + out_chunk.n_len = n_len; + out_chunk.n_mel = out_full.n_mel; + out_chunk.n_len_org = out_full.n_mel; // unused + out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len); + + for (int i = 0; i < out_full.n_mel; i++) { + auto src = out_full.data.begin() + i*out_full.n_len + off; + out_chunk.data.insert(out_chunk.data.end(), 
src, src + frames_per_chunk); + } + + output.push_back(std::move(out_chunk)); + } + + return true; +} + +} // namespace whisper_preprocessor + + +// precalculated mel filter banks +// values are multiplied by 1000.0 to save space, and will be divided by 1000.0 in the end of the function +// +// generated from python code: +// +// from numpy import load +// data = load('mel_filters.npz') +// lst = data.files +// for item in lst: +// print(item) +// print(data[item].shape) +// n_mel = data[item].shape[0] +// n_fft = data[item].shape[1] +// for i, row in enumerate(data[item]): +// for j, val in enumerate(row): +// val = val * 1000.0 +// if val != 0: +// print(f"data[{i*n_fft + j}] = {val:.6f};") + +namespace whisper_precalc_filters { + +whisper_preprocessor::whisper_filters get_128_bins() { + whisper_preprocessor::whisper_filters filters; + filters.n_mel = 128; + filters.n_fft = 201; + std::vector data(filters.n_mel * filters.n_fft, 0.0f); + + data[1] = 12.37398665; + data[202] = 30.39256483; + data[404] = 24.74797331; + data[605] = 18.01857911; + data[807] = 37.12195903; + data[1008] = 5.64459199; + data[1009] = 6.72939420; + data[1210] = 36.03715822; + data[1412] = 19.10337992; + data[1613] = 23.66316877; + data[1815] = 31.47736564; + data[2016] = 11.28918398; + data[2017] = 1.08480197; + data[2218] = 41.68175161; + data[2420] = 13.45878839; + data[2621] = 29.30776216; + data[2823] = 25.83277412; + data[3024] = 16.93377644; + data[3226] = 38.20675984; + data[3427] = 4.55979025; + data[3428] = 7.81419594; + data[3629] = 34.95235741; + data[3831] = 20.18818259; + data[4032] = 22.57836796; + data[4234] = 32.56217018; + data[4435] = 10.20438317; + data[4436] = 2.16960395; + data[4637] = 40.59694707; + data[4839] = 14.54358920; + data[5040] = 28.22295949; + data[5242] = 26.91757679; + data[5443] = 15.84897563; + data[5645] = 39.29156065; + data[5846] = 3.47498828; + data[5847] = 8.89899861; + data[6048] = 33.86755288; + data[6250] = 21.27298526; + data[6451] = 21.49356715; + data[6653] = 33.64697099; + data[6854] = 9.11958050; + data[6855] = 3.25440569; + data[7056] = 39.51214626; + data[7258] = 15.62839188; + data[7459] = 27.13815868; + data[7661] = 28.00237760; + data[7862] = 14.76417296; + data[8064] = 40.37636518; + data[8265] = 2.38068704; + data[8266] = 10.20263787; + data[8467] = 31.61146119; + data[8669] = 24.54700135; + data[8870] = 15.32919332; + data[8871] = 1.66583748; + data[9072] = 36.72905266; + data[9274] = 20.09709924; + data[9475] = 16.93102531; + data[9476] = 2.90265540; + data[9677] = 32.84499049; + data[9879] = 23.52004871; + data[10080] = 11.03894413; + data[10081] = 10.72582975; + data[10282] = 22.71829173; + data[10484] = 32.27872774; + data[10685] = 0.11626833; + data[10686] = 22.85348251; + data[10887] = 8.56344029; + data[10888] = 14.97978810; + data[11089] = 15.51398356; + data[11090] = 8.51490628; + data[11291] = 21.10680379; + data[11292] = 3.32652032; + data[11493] = 25.47064796; + data[11695] = 27.35907957; + data[11896] = 0.65853616; + data[11897] = 23.83812517; + data[12098] = 3.44359246; + data[12099] = 21.22455277; + data[12300] = 5.35842171; + data[12301] = 19.42555793; + data[12502] = 6.49324711; + data[12503] = 18.35542172; + data[12704] = 6.93138083; + data[12705] = 17.93504693; + data[12906] = 6.74968259; + data[12907] = 18.09151843; + data[13108] = 6.01899112; + data[13109] = 18.75767298; + data[13310] = 4.80452832; + data[13311] = 19.87172849; + data[13512] = 3.16627859; + data[13513] = 21.37690969; + data[13514] = 1.25317345; + data[13714] = 
1.15934468; + data[13715] = 20.80361731; + data[13716] = 4.04486805; + data[13917] = 17.55363122; + data[13918] = 7.08320038; + data[14119] = 14.07538634; + data[14120] = 10.32655034; + data[14321] = 10.40921453; + data[14322] = 13.73696327; + data[14523] = 6.59187697; + data[14524] = 17.27988198; + data[14525] = 1.46804214; + data[14725] = 2.65681883; + data[14726] = 18.09193194; + data[14727] = 5.85655728; + data[14928] = 13.34277913; + data[14929] = 10.28267574; + data[15130] = 8.56800377; + data[15131] = 14.72230814; + data[15132] = 1.04039861; + data[15332] = 3.79085587; + data[15333] = 17.14678481; + data[15334] = 6.11609267; + data[15535] = 11.75929047; + data[15536] = 11.13393717; + data[15737] = 6.43857848; + data[15738] = 16.07806236; + data[15739] = 4.23917221; + data[15939] = 1.19989377; + data[15940] = 12.75671553; + data[15941] = 9.65298992; + data[16142] = 7.06935255; + data[16143] = 14.94054683; + data[16144] = 4.19024844; + data[16344] = 1.51483389; + data[16345] = 12.00899947; + data[16346] = 9.84823331; + data[16547] = 6.10224018; + data[16548] = 15.33857174; + data[16549] = 5.57676842; + data[16749] = 0.36827257; + data[16750] = 9.89749376; + data[16751] = 11.35340426; + data[16752] = 2.05122307; + data[16952] = 3.89297144; + data[16953] = 12.97352277; + data[16954] = 8.06631614; + data[17155] = 6.74493238; + data[17156] = 13.85874674; + data[17157] = 5.41190524; + data[17357] = 0.74220158; + data[17358] = 8.98779090; + data[17359] = 11.37871388; + data[17360] = 3.32958088; + data[17560] = 2.82313535; + data[17561] = 10.68049297; + data[17562] = 9.43340641; + data[17563] = 1.76325557; + data[17763] = 4.39018616; + data[17764] = 11.87758986; + data[17765] = 7.97005836; + data[17766] = 0.66104700; + data[17966] = 5.49466675; + data[17967] = 12.62953598; + data[17968] = 6.93987962; + data[18169] = 6.18401915; + data[18170] = 12.93473132; + data[18171] = 6.29778765; + data[18371] = 0.02325210; + data[18372] = 6.50206627; + data[18373] = 12.32661773; + data[18374] = 6.00216538; + data[18574] = 0.31548753; + data[18575] = 6.48925547; + data[18576] = 12.04130240; + data[18577] = 6.01462880; + data[18777] = 0.29979556; + data[18778] = 6.18288014; + data[18779] = 12.04272825; + data[18780] = 6.29981188; + data[18781] = 0.55689598; + data[18980] = 0.01120471; + data[18981] = 5.61729167; + data[18982] = 11.22337859; + data[18983] = 6.82516303; + data[18984] = 1.35264499; + data[19184] = 4.82410006; + data[19185] = 10.16623247; + data[19186] = 7.56075513; + data[19187] = 2.34590308; + data[19387] = 3.83235747; + data[19388] = 8.92296247; + data[19389] = 8.47910438; + data[19390] = 3.50978645; + data[19590] = 2.66873185; + data[19591] = 7.51965167; + data[19592] = 9.55500547; + data[19593] = 4.81966138; + data[19594] = 0.08431751; + data[19793] = 1.35767367; + data[19794] = 5.98019501; + data[19795] = 10.60271543; + data[19796] = 6.25298498; + data[19797] = 1.74059917; + data[19997] = 4.32644226; + data[19998] = 8.73131864; + data[19999] = 7.78916525; + data[20000] = 3.48923868; + data[20200] = 2.57835095; + data[20201] = 6.77582854; + data[20202] = 9.40941647; + data[20203] = 5.31194592; + data[20204] = 1.21447595; + data[20403] = 0.75411191; + data[20404] = 4.75395704; + data[20405] = 8.75380263; + data[20406] = 7.19209015; + data[20407] = 3.28754401; + data[20607] = 2.68179690; + data[20608] = 6.49331464; + data[20609] = 9.11457930; + data[20610] = 5.39387390; + data[20611] = 1.67316827; + data[20810] = 0.57394296; + data[20811] = 4.20600036; + data[20812] = 7.83805829; + 
data[20813] = 7.52023002; + data[20814] = 3.97470826; + data[20815] = 0.42918732; + data[21014] = 1.90464477; + data[21015] = 5.36569161; + data[21016] = 8.82673822; + data[21017] = 6.27609482; + data[21018] = 2.89750961; + data[21218] = 2.89885257; + data[21219] = 6.19694078; + data[21220] = 8.56699049; + data[21221] = 5.34748193; + data[21222] = 2.12797290; + data[21421] = 0.44750227; + data[21422] = 3.59030394; + data[21423] = 6.73310598; + data[21424] = 7.77023612; + data[21425] = 4.70231380; + data[21426] = 1.63439126; + data[21625] = 1.01536023; + data[21626] = 4.01018746; + data[21627] = 7.00501446; + data[21628] = 7.23442994; + data[21629] = 4.31095669; + data[21630] = 1.38748321; + data[21829] = 1.33348850; + data[21830] = 4.18730825; + data[21831] = 7.04112789; + data[21832] = 6.93188375; + data[21833] = 4.14605811; + data[21834] = 1.36023236; + data[22033] = 1.42879714; + data[22034] = 4.14824858; + data[22035] = 6.86769979; + data[22036] = 6.83705276; + data[22037] = 4.18239459; + data[22038] = 1.52773573; + data[22237] = 1.32610439; + data[22238] = 3.91751388; + data[22239] = 6.50892360; + data[22240] = 6.92639686; + data[22241] = 4.39672917; + data[22242] = 1.86706171; + data[22441] = 1.04827771; + data[22442] = 3.51767405; + data[22443] = 5.98707050; + data[22444] = 7.17824046; + data[22445] = 4.76767914; + data[22446] = 2.35711760; + data[22645] = 0.61636406; + data[22646] = 2.96949223; + data[22647] = 5.32262027; + data[22648] = 7.57265091; + data[22649] = 5.27558755; + data[22650] = 2.97852419; + data[22651] = 0.68146095; + data[22849] = 0.04971400; + data[22850] = 2.29204819; + data[22851] = 4.53438237; + data[22852] = 6.77671656; + data[22853] = 5.90240723; + data[22854] = 3.71349836; + data[22855] = 1.52458926; + data[23054] = 1.50285335; + data[23055] = 3.63961048; + data[23056] = 5.77636715; + data[23057] = 6.63159089; + data[23058] = 4.54574358; + data[23059] = 2.45989650; + data[23060] = 0.37404924; + data[23258] = 0.61795861; + data[23259] = 2.65410915; + data[23260] = 4.69025923; + data[23261] = 6.72641024; + data[23262] = 5.46034705; + data[23263] = 3.47270933; + data[23264] = 1.48507138; + data[23463] = 1.59233576; + data[23464] = 3.53261665; + data[23465] = 5.47289755; + data[23466] = 6.44368259; + data[23467] = 4.54962999; + data[23468] = 2.65557761; + data[23469] = 0.76152512; + data[23667] = 0.46749352; + data[23668] = 2.31641904; + data[23669] = 4.16534441; + data[23670] = 6.01426978; + data[23671] = 5.67844696; + data[23672] = 3.87357362; + data[23673] = 2.06870004; + data[23674] = 0.26382666; + data[23872] = 1.05349103; + data[23873] = 2.81536230; + data[23874] = 4.57723346; + data[23875] = 6.33910485; + data[23876] = 5.12815686; + data[23877] = 3.40826320; + data[23878] = 1.68837002; + data[24077] = 1.43350090; + data[24078] = 3.11241671; + data[24079] = 4.79133241; + data[24080] = 6.40943693; + data[24081] = 4.77052201; + data[24082] = 3.13160778; + data[24083] = 1.49269309; + data[24281] = 0.02932359; + data[24282] = 1.62918994; + data[24283] = 3.22905602; + data[24284] = 4.82892245; + data[24285] = 6.14671456; + data[24286] = 4.58496623; + data[24287] = 3.02321767; + data[24288] = 1.46146910; + data[24486] = 0.13601698; + data[24487] = 1.66055572; + data[24488] = 3.18509457; + data[24489] = 4.70963307; + data[24490] = 6.04072399; + data[24491] = 4.55250870; + data[24492] = 3.06429295; + data[24493] = 1.57607743; + data[24494] = 0.08786193; + data[24691] = 0.09328097; + data[24692] = 1.54603878; + data[24693] = 2.99879676; + data[24694] = 4.45155473; 
+    data[24695] = 5.90431225;
+    data[24696] = 4.65566106;
+    data[24697] = 3.23751615;
+    data[24698] = 1.81937125;
+    data[24699] = 0.40122634;
+    data[24897] = 1.30262633;
+    data[24898] = 2.68698297;
+    data[24899] = 4.07133950;
+    data[24900] = 5.45569602;
+    data[24901] = 4.87832492;
+    data[24902] = 3.52695142;
+    data[24903] = 2.17557792;
+    data[24904] = 0.82420459;
+    data[25102] = 0.94595028;
+    data[25103] = 2.26512621;
+    data[25104] = 3.58430226;
+    data[25105] = 4.90347855;
+    data[25106] = 5.20569785;
+    data[25107] = 3.91795207;
+    data[25108] = 2.63020652;
+    data[25109] = 1.34246063;
+    data[25110] = 0.05471494;
+    data[25307] = 0.49037894;
+    data[25308] = 1.74744334;
+    data[25309] = 3.00450763;
+    data[25310] = 4.26157191;
+    data[25311] = 5.51863620;
+    data[25312] = 4.39707236;
+    data[25313] = 3.16995848;
+    data[25314] = 1.94284460;
+    data[25315] = 0.71573065;
+    data[25513] = 1.14698056;
+    data[25514] = 2.34485767;
+    data[25515] = 3.54273478;
+    data[25516] = 4.74061165;
+    data[25517] = 4.95198462;
+    data[25518] = 3.78264743;
+    data[25519] = 2.61331047;
+    data[25520] = 1.44397374;
+    data[25521] = 0.27463681;
+    data[25718] = 0.47569509;
+    data[25719] = 1.61717169;
+    data[25720] = 2.75864848;
+    data[25721] = 3.90012516;
+    data[25722] = 5.04160160;
+    data[25723] = 4.45712078;
+    data[25724] = 3.34284059;
+    data[25725] = 2.22856039;
+    data[25726] = 1.11428020;
+
+    for (auto & val : data) {
+        val /= 1000.0f;
+    }
+
+    filters.data = std::move(data);
+    return filters;
+}
+
+} // namespace whisper_precalc_filters
diff --git a/tools/mtmd/mtmd-audio.h b/tools/mtmd/mtmd-audio.h
new file mode 100644
index 000000000..b7b940aff
--- /dev/null
+++ b/tools/mtmd/mtmd-audio.h
@@ -0,0 +1,47 @@
+#pragma once
+
+#include "ggml.h"
+
+#include <cstdint>
+#include <vector>
+#include 
+
+#define WHISPER_ASSERT GGML_ASSERT
+
+#define WHISPER_SAMPLE_RATE 16000
+#define WHISPER_N_FFT 400
+#define WHISPER_HOP_LENGTH 160
+#define WHISPER_CHUNK_SIZE 30
+
+#define COMMON_SAMPLE_RATE 16000
+
+namespace whisper_preprocessor {
+
+struct whisper_mel {
+    int n_len;
+    int n_len_org;
+    int n_mel;
+
+    std::vector<float> data;
+};
+
+struct whisper_filters {
+    int32_t n_mel;
+    int32_t n_fft;
+
+    std::vector<float> data;
+};
+
+bool preprocess_audio(
+        const float * samples,
+        size_t n_samples,
+        const whisper_filters & filters,
+        std::vector<whisper_mel> & output);
+
+} // namespace whisper_preprocessor
+
+namespace whisper_precalc_filters {
+
+whisper_preprocessor::whisper_filters get_128_bins();
+
+} // namespace whisper_precalc_filters
diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp
index 4977d5480..508a64c58 100644
--- a/tools/mtmd/mtmd-cli.cpp
+++ b/tools/mtmd/mtmd-cli.cpp
@@ -7,6 +7,7 @@
 #include "console.h"
 #include "chat.h"
 #include "mtmd.h"
+#include "mtmd-helper.h"
 
 #include 
 #include 
@@ -37,10 +38,10 @@ static volatile bool g_is_interrupted = false;
 
 static void show_additional_info(int /*argc*/, char ** argv) {
     LOG(
         "Experimental CLI for multimodal\n\n"
-        "Usage: %s [options] -m --mmproj --image -p \n\n"
+        "Usage: %s [options] -m --mmproj --image --audio
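
Editor's note: for orientation, below is a minimal usage sketch of the whisper_preprocessor / whisper_precalc_filters API introduced by this patch. It is not part of the diff; the driver program and the silent test signal are illustrative assumptions, and only the types, constants and function declarations are taken from tools/mtmd/mtmd-audio.h above.

// hypothetical example driver, not included in the patch
#include "mtmd-audio.h"

#include <cstdio>
#include <vector>

int main() {
    // one second of silence at 16 kHz as a stand-in for real PCM input
    std::vector<float> pcm(WHISPER_SAMPLE_RATE, 0.0f);

    // precalculated 128-bin mel filter bank defined in mtmd-audio.cpp above
    const whisper_preprocessor::whisper_filters filters = whisper_precalc_filters::get_128_bins();

    // preprocess_audio() returns the mel spectrogram split into chunks of
    // WHISPER_CHUNK_SIZE * WHISPER_SAMPLE_RATE / WHISPER_HOP_LENGTH = 30 * 16000 / 160 = 3000 frames
    std::vector<whisper_preprocessor::whisper_mel> chunks;
    if (!whisper_preprocessor::preprocess_audio(pcm.data(), pcm.size(), filters, chunks)) {
        std::fprintf(stderr, "audio preprocessing failed\n");
        return 1;
    }

    for (const auto & mel : chunks) {
        std::printf("chunk: n_mel = %d, n_len = %d, n_len_org = %d\n", mel.n_mel, mel.n_len, mel.n_len_org);
    }
    return 0;
}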