diff --git a/.all_crates.sh b/.all_crates.sh index 859bbff144..8ba76a3531 100644 --- a/.all_crates.sh +++ b/.all_crates.sh @@ -1,2 +1,2 @@ -ALL_CRATES_PATH="data linalg core nnef nnef/nnef-resources pulse-opl pulse extra transformers hir tflite tensorflow onnx-opl onnx gpu metal cuda libcli api api/rs api/ffi api/proxy/sys api/proxy cli" +ALL_CRATES_PATH="data linalg core nnef nnef/nnef-resources transformers pulse-opl pulse extra hir tflite tensorflow onnx-opl onnx gpu metal cuda libcli api api/rs api/ffi api/proxy/sys api/proxy cli" diff --git a/.github/scripts/inject_wheel_sboms.py b/.github/scripts/inject_wheel_sboms.py new file mode 100644 index 0000000000..fcb6c17f47 --- /dev/null +++ b/.github/scripts/inject_wheel_sboms.py @@ -0,0 +1,85 @@ +"""Inject CycloneDX + SPDX SBOMs into a wheel's `.dist-info/sboms/` per PEP 770. + +Reads the wheel, generates SBOMs from its unpacked contents via `syft` +(which understands the `cargo-auditable` section embedded in the +bundled Rust dylib), drops them into the wheel's metadata directory, +and re-packs the wheel. `wheel pack` regenerates RECORD so the new +files are properly listed and hashed. + +Usage: python inject_wheel_sboms.py wheel-1.whl wheel-2.whl ... + (in-place; replaces each input wheel with the SBOM-bearing one) + +Requires: `syft` on PATH, and the `wheel` Python package installed. 
+""" + +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path + + +def inject(wheel_path: Path) -> None: + wheel_path = wheel_path.resolve() + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + unpack_root = tmp / "unpacked" + repack_root = tmp / "repacked" + unpack_root.mkdir() + repack_root.mkdir() + + subprocess.check_call( + [sys.executable, "-m", "wheel", "unpack", "-d", str(unpack_root), str(wheel_path)] + ) + + # `wheel unpack` writes one top-level dir named `<name>-<version>` + (unpacked,) = list(unpack_root.iterdir()) + (dist_info,) = list(unpacked.glob("*.dist-info")) + + sboms_dir = dist_info / "sboms" + sboms_dir.mkdir(exist_ok=True) + + # syft scans the unpacked tree. Its rust-audit-binary cataloger + # reads the `.dep-v0` ELF/Mach-O section that `cargo-auditable` + # embedded; the Python cataloger picks up METADATA. + subprocess.check_call( + [ + "syft", + "scan", + f"dir:{unpacked}", + "--source-name", + unpacked.name, + "-o", + f"cyclonedx-json={sboms_dir / 'sbom.cdx.json'}", + "-o", + f"spdx-json={sboms_dir / 'sbom.spdx.json'}", + ] + ) + + # `wheel pack` rewrites RECORD with hashes for every file under + # `unpacked/`, including the two SBOMs we just added. + subprocess.check_call( + [sys.executable, "-m", "wheel", "pack", "-d", str(repack_root), str(unpacked)] + ) + + (repacked_wheel,) = list(repack_root.glob("*.whl")) + # Names should match; if `wheel pack` produced a different + # filename (e.g. build-tag difference), prefer the new name.
+ target = wheel_path.parent / repacked_wheel.name + if target != wheel_path: + wheel_path.unlink() + shutil.move(str(repacked_wheel), str(target)) + print(f"injected SBOMs into {target.name}") + + +def main(argv: list[str]) -> int: + if not argv: + print(__doc__, file=sys.stderr) + return 2 + for w in argv: + inject(Path(w)) + return 0 + + +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/.github/workflows/asan.yml b/.github/workflows/asan.yml index e0b8385e70..5db9fcfe26 100644 --- a/.github/workflows/asan.yml +++ b/.github/workflows/asan.yml @@ -22,7 +22,7 @@ jobs: runs-on: ${{matrix.os}} steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Rustup update diff --git a/.github/workflows/binaries.yml b/.github/workflows/binaries.yml index 6d91302454..8ff670967f 100644 --- a/.github/workflows/binaries.yml +++ b/.github/workflows/binaries.yml @@ -22,6 +22,11 @@ jobs: name: Upload Release Binaries permissions: contents: write + # OIDC + attestations: cryptographically sign each released .tgz + # against its SBOMs via GitHub's attestation store. Anyone can + # then verify with `gh attestation verify --owner sonos`. + id-token: write + attestations: write strategy: fail-fast: false matrix: @@ -53,7 +58,7 @@ jobs: runs-on: ${{ matrix.os }} steps: - name: Checkout code - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false @@ -98,15 +103,59 @@ jobs: export CARGO_TARGET_${RUST_TRIPLE_ENV}_LINKER=$TARGET_CC fi - cargo build --target ${target} --release -p tract-cli + # cargo-auditable wraps `cargo build` to embed the resolved + # dependency graph into the binary so consumers can recover + # the SBOM directly from the binary via `cargo audit bin`. 
+ cargo install --locked cargo-auditable --version 0.7.0 + cargo auditable build --target ${target} --release -p tract-cli --locked mkdir tract-$name cp target/${target}/release/tract tract-${name} tar czf tract-${name}.tgz tract-${name} + - name: Generate CycloneDX SBOM + uses: anchore/sbom-action@e22c389904149dbc22b58101806040fa8d37a610 # v0.24.0 + with: + path: . + format: cyclonedx-json + artifact-name: tract-${{ matrix.target }}-${{ steps.version.outputs.value }}.cdx.json + output-file: tract-${{ matrix.target }}-${{ steps.version.outputs.value }}.cdx.json + upload-artifact: false + upload-release-assets: false + + - name: Generate SPDX SBOM + uses: anchore/sbom-action@e22c389904149dbc22b58101806040fa8d37a610 # v0.24.0 + with: + path: . + format: spdx-json + artifact-name: tract-${{ matrix.target }}-${{ steps.version.outputs.value }}.spdx.json + output-file: tract-${{ matrix.target }}-${{ steps.version.outputs.value }}.spdx.json + upload-artifact: false + upload-release-assets: false + + - name: Attest build provenance + uses: actions/attest-build-provenance@a2bbfa25375fe432b6a289bc6b6cd05ecd0c4c32 # v4.1.0 + with: + subject-path: tract-${{ matrix.target }}-${{ steps.version.outputs.value }}.tgz + + - name: Attest release tarball with CycloneDX SBOM + uses: actions/attest-sbom@c604332985a26aa8cf1bdc465b92731239ec6b9e # v4.1.0 + with: + subject-path: tract-${{ matrix.target }}-${{ steps.version.outputs.value }}.tgz + sbom-path: tract-${{ matrix.target }}-${{ steps.version.outputs.value }}.cdx.json + + - name: Attest release tarball with SPDX SBOM + uses: actions/attest-sbom@c604332985a26aa8cf1bdc465b92731239ec6b9e # v4.1.0 + with: + subject-path: tract-${{ matrix.target }}-${{ steps.version.outputs.value }}.tgz + sbom-path: tract-${{ matrix.target }}-${{ steps.version.outputs.value }}.spdx.json + - name: Upload asset uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3 with: - files: tract-${{matrix.target}}-${{ 
steps.version.outputs.value }}.tgz + files: | + tract-${{matrix.target}}-${{ steps.version.outputs.value }}.tgz + tract-${{matrix.target}}-${{ steps.version.outputs.value }}.cdx.json + tract-${{matrix.target}}-${{ steps.version.outputs.value }}.spdx.json name: ${{ steps.version.outputs.value }} tag_name: v${{ steps.version.outputs.value }} env: diff --git a/.github/workflows/cost_model.yml b/.github/workflows/cost_model.yml index 8af23aeb1e..7e7f9ed260 100644 --- a/.github/workflows/cost_model.yml +++ b/.github/workflows/cost_model.yml @@ -25,7 +25,7 @@ jobs: target: [ "aarch64", "armv7" ] steps: - name: Checkout code - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false diff --git a/.github/workflows/crates.yml b/.github/workflows/crates.yml index 0acd81f364..fce165287c 100644 --- a/.github/workflows/crates.yml +++ b/.github/workflows/crates.yml @@ -19,20 +19,26 @@ jobs: outputs: os: ${{steps.set-matrix.outputs.os}} rust: ${{steps.set-matrix.outputs.rust}} + msrv: ${{steps.set-matrix.outputs.msrv}} steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false - id: set-matrix env: FULL: ${{ github.event_name == 'workflow_dispatch' || github.event_name == 'schedule' }} run: | + msrv=$(grep -m1 '^rust-version' Cargo.toml | sed -E 's/.*"([^"]+)".*/\1/') + echo "msrv=$msrv" >> $GITHUB_OUTPUT if [ "$FULL" == "true" ] then echo 'os=["ubuntu-latest", "macos-latest"]' >> $GITHUB_OUTPUT - echo 'rust=["1.91.0", "stable", "beta", "nightly"]' >> $GITHUB_OUTPUT + echo "rust=[\"$msrv\", \"stable\", \"beta\", \"nightly\"]" >> $GITHUB_OUTPUT else echo ::notice::Skipping macOS checks on PR and commit. Dispatch workflow manually if needed. 
echo 'os=["ubuntu-latest"]' >> $GITHUB_OUTPUT - echo 'rust=["1.91.0"]' >> $GITHUB_OUTPUT + echo "rust=[\"$msrv\"]" >> $GITHUB_OUTPUT fi crates: @@ -53,7 +59,7 @@ jobs: RUSTUP_TOOLCHAIN: ${{matrix.rust}} steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false @@ -70,7 +76,7 @@ jobs: env: RUSTUP_TOOLCHAIN: ${{matrix.rust}} steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false @@ -85,9 +91,9 @@ jobs: runs-on: cuda-lovelace needs: prepare-matrix env: - RUSTUP_TOOLCHAIN: "1.91.0" + RUSTUP_TOOLCHAIN: ${{ needs.prepare-matrix.outputs.msrv }} steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Minimum-BOM GPU smoke @@ -103,7 +109,7 @@ jobs: env: RUSTUP_TOOLCHAIN: ${{matrix.rust}} steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false @@ -122,12 +128,12 @@ jobs: env: RUSTUP_TOOLCHAIN: ${{matrix.rust}} steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - run: rustup component add clippy && cargo clippy - name: fmt - run: rustup component add rustfmt && cargo fmt --check + run: rustup component add rustfmt --toolchain stable && cargo +stable fmt --check - name: Warnings env: RUSTFLAGS: -D warnings @@ -140,7 +146,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 
with: persist-credentials: false - name: Install cargo-deny diff --git a/.github/workflows/cross-platform.yml b/.github/workflows/cross-platform.yml index f0c491d29b..3b9294555b 100644 --- a/.github/workflows/cross-platform.yml +++ b/.github/workflows/cross-platform.yml @@ -14,7 +14,6 @@ on: env: CARGO_INCREMENTAL: false FORCE_JAVASCRIPT_ACTIONS_TO_NODE20: true - RUSTUP_TOOLCHAIN: 1.91.0 permissions: contents: read @@ -83,7 +82,7 @@ jobs: contents: read steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ needs.prepare.outputs.test_ref }} fetch-depth: 0 @@ -95,7 +94,7 @@ jobs: - name: Configure AWS Credentials continue-on-error: true - uses: aws-actions/configure-aws-credentials@d979d5b3a71173a29b74b5b88418bfda9437d885 # v6 + uses: aws-actions/configure-aws-credentials@acca2b1b2070338fb9fd1ca27ecee81d687e58e5 # v6.1.2 with: role-to-assume: arn:aws:iam::567805100031:role/github-runner-tract-ci aws-region: us-east-2 @@ -141,7 +140,7 @@ jobs: contents: read steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ needs.prepare.outputs.test_ref }} fetch-depth: 0 @@ -149,7 +148,7 @@ jobs: - name: Configure AWS Credentials continue-on-error: true - uses: aws-actions/configure-aws-credentials@d979d5b3a71173a29b74b5b88418bfda9437d885 # v6 + uses: aws-actions/configure-aws-credentials@acca2b1b2070338fb9fd1ca27ecee81d687e58e5 # v6.1.2 with: role-to-assume: arn:aws:iam::567805100031:role/github-runner-tract-ci aws-region: us-east-2 diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index bc48fc6c58..65d78e7737 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -8,7 +8,6 @@ on: env: CARGO_INCREMENTAL: false FORCE_JAVASCRIPT_ACTIONS_TO_NODE20: true - RUSTUP_TOOLCHAIN: 1.91.0 permissions: 
contents: read @@ -20,7 +19,7 @@ jobs: examples: ${{steps.set-matrix.outputs.examples}} steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - id: set-matrix @@ -40,14 +39,14 @@ jobs: ex: ${{fromJSON(needs.examples.outputs.examples)}} steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Configure AWS Credentials # if: github.repository == 'sonos/tract' continue-on-error: true - uses: aws-actions/configure-aws-credentials@d979d5b3a71173a29b74b5b88418bfda9437d885 # v6 + uses: aws-actions/configure-aws-credentials@acca2b1b2070338fb9fd1ca27ecee81d687e58e5 # v6.1.2 with: role-to-assume: arn:aws:iam::567805100031:role/github-runner-tract-ci aws-region: us-east-2 @@ -64,7 +63,7 @@ jobs: build-tract-cli: runs-on: ubuntu-latest steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - run: cargo build -p tract-cli --profile opt-no-lto @@ -76,7 +75,7 @@ jobs: build-tract-cli-macos: runs-on: macOS steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - run: cargo build -p tract-cli --profile opt-no-lto @@ -91,7 +90,7 @@ jobs: examples: ${{steps.set-matrix.outputs.examples}} steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - id: set-matrix @@ -108,7 +107,7 @@ jobs: ex: ${{fromJSON(needs.gpu-examples.outputs.examples)}} steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 
+ - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false @@ -137,7 +136,7 @@ jobs: ex: ${{fromJSON(needs.gpu-examples.outputs.examples)}} steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false diff --git a/.github/workflows/full.yml b/.github/workflows/full.yml index fba7c61ce5..1d48ae2b98 100644 --- a/.github/workflows/full.yml +++ b/.github/workflows/full.yml @@ -51,7 +51,7 @@ jobs: contents: read needs: prepare steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ needs.prepare.outputs.test_ref }} fetch-depth: 0 @@ -59,7 +59,7 @@ jobs: - name: Configure AWS Credentials continue-on-error: true - uses: aws-actions/configure-aws-credentials@d979d5b3a71173a29b74b5b88418bfda9437d885 # v6 + uses: aws-actions/configure-aws-credentials@acca2b1b2070338fb9fd1ca27ecee81d687e58e5 # v6.1.2 with: role-to-assume: arn:aws:iam::567805100031:role/github-runner-tract-ci aws-region: us-east-2 @@ -72,7 +72,7 @@ jobs: needs: prepare steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ needs.prepare.outputs.test_ref }} fetch-depth: 0 @@ -90,7 +90,7 @@ jobs: opset: [1_4_1, 1_5_0, 1_6_0, 1_7_0, 1_8_1, 1_9_0, 1_10_2, 1_11_0, 1_12_0, 1_13_0, 1_14_1, 1_15_0, 1_16_2, 1_17_0, 1_18_0, 1_19_1] steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ needs.prepare.outputs.test_ref }} fetch-depth: 0 @@ -103,7 +103,7 @@ jobs: needs: prepare steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: 
actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ needs.prepare.outputs.test_ref }} fetch-depth: 0 @@ -116,7 +116,7 @@ jobs: needs: prepare steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ needs.prepare.outputs.test_ref }} fetch-depth: 0 @@ -132,7 +132,7 @@ jobs: needs: prepare steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ needs.prepare.outputs.test_ref }} fetch-depth: 0 @@ -147,7 +147,7 @@ jobs: needs: prepare steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ needs.prepare.outputs.test_ref }} fetch-depth: 0 @@ -162,7 +162,7 @@ jobs: needs: prepare steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ needs.prepare.outputs.test_ref }} fetch-depth: 0 @@ -177,7 +177,7 @@ jobs: needs: prepare steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ needs.prepare.outputs.test_ref }} fetch-depth: 0 @@ -193,7 +193,7 @@ jobs: needs: prepare steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ needs.prepare.outputs.test_ref }} fetch-depth: 0 diff --git a/.github/workflows/large_models.yml b/.github/workflows/large_models.yml index 05d29e07ec..822a94d3d1 100644 --- a/.github/workflows/large_models.yml +++ b/.github/workflows/large_models.yml @@ -5,6 +5,11 @@ on: schedule: - cron: '0 3 * * *' workflow_dispatch: + inputs: + pr_number: 
+ description: "Optional PR number to test (from fork ok). Leave empty to run on selected branch." + required: false + type: number env: LARGE_MODELS: true @@ -13,15 +18,42 @@ permissions: contents: read jobs: - cli: + prepare: + runs-on: ubuntu-latest + outputs: + test_ref: ${{ steps.set.outputs.test_ref }} + steps: + - id: set + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9 + with: + script: | + const prInput = context.payload.inputs?.pr_number ?? context.payload.pull_request?.number; + let ref; + if (prInput) { + const pr = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: Number(prInput), + }); + ref = pr.data.head.sha; + } else { + ref = process.env.GITHUB_SHA; + } + core.info(`test_ref: ${ref}`); + core.setOutput('test_ref', ref); + + cli: name: Build tract on ${{ matrix.os }} + needs: prepare runs-on: ${{ matrix.os }} strategy: matrix: os: [ macos-latest, ubuntu-latest ] steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: + ref: ${{ needs.prepare.outputs.test_ref }} + fetch-depth: 0 persist-credentials: false - run: | ROOT=. 
./.travis/ci-system-setup.sh @@ -33,6 +65,7 @@ jobs: path: ./target/opt-no-lto/tract foundation-llms: + needs: prepare runs-on: ubuntu-latest outputs: models: ${{steps.set-matrix.outputs.models}} @@ -55,7 +88,7 @@ jobs: foundation-llm: name: ${{ matrix.os }} / ${{matrix.rt}} / ${{ matrix.model }} / ${{ matrix.q }} - needs: [ cli, foundation-llms ] + needs: [ prepare, cli, foundation-llms ] runs-on: ${{ matrix.os }} strategy: matrix: @@ -84,12 +117,14 @@ jobs: contents: read steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: + ref: ${{ needs.prepare.outputs.test_ref }} + fetch-depth: 0 persist-credentials: false - name: Configure AWS Credentials continue-on-error: true - uses: aws-actions/configure-aws-credentials@d979d5b3a71173a29b74b5b88418bfda9437d885 # v6 + uses: aws-actions/configure-aws-credentials@acca2b1b2070338fb9fd1ca27ecee81d687e58e5 # v6.1.2 with: role-to-assume: arn:aws:iam::567805100031:role/github-runner-tract-ci aws-region: us-east-2 @@ -119,7 +154,7 @@ jobs: parakeet-tdt-600m-v3: name: ${{matrix.os}} / Parakeet TDT 600m v3 - needs: [ cli ] + needs: [ prepare, cli ] strategy: matrix: os: [ macOS, cuda-lovelace ] @@ -129,8 +164,10 @@ jobs: contents: read runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: + ref: ${{ needs.prepare.outputs.test_ref }} + fetch-depth: 0 persist-credentials: false - run: echo uname=$(uname) >> $GITHUB_ENV - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 @@ -148,7 +185,7 @@ jobs: nemotron-speech-streaming-en-06b: name: ${{matrix.os}} / Nemotron speech streaming en 0.6b - needs: [ cli ] + needs: [ prepare, cli ] strategy: matrix: os: [ macOS, cuda-lovelace ] @@ -158,8 +195,10 @@ jobs: contents: read runs-on: ${{ matrix.os }} steps: - - uses: 
actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: + ref: ${{ needs.prepare.outputs.test_ref }} + fetch-depth: 0 persist-credentials: false - run: echo uname=$(uname) >> $GITHUB_ENV - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 diff --git a/.github/workflows/pydoc.yml b/.github/workflows/pydoc.yml index f362fb08b0..68db3b328a 100644 --- a/.github/workflows/pydoc.yml +++ b/.github/workflows/pydoc.yml @@ -20,7 +20,7 @@ jobs: contents: write steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false @@ -85,13 +85,16 @@ jobs: json.dump(versions, f, indent=2) " - # commit and push + # commit and push. actions/checkout ran with persist-credentials: + # false, so no auth is wired into .git/config — push via an + # explicit token URL instead. 
git add -A git commit -m "Update Python docs ($version)" || true - git push origin gh-pages + git push "https://x-access-token:${GH_TOKEN}@github.com/${GITHUB_REPOSITORY}.git" gh-pages # clean up worktree cd - git worktree remove "$workdir" env: STEPS_VERSION_OUTPUTS_VALUE: ${{ steps.version.outputs.value }} + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3c62678048..7b20dad226 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -23,7 +23,7 @@ jobs: id: version run: echo value=$(echo ${GITHUB_REF} | cut -f 3 -d / | sed 's/^v//' ) >> $GITHUB_OUTPUT - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index e9cd14640a..2f573c62fa 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -36,12 +36,13 @@ jobs: os: [ubuntu-22.04, windows-2022, macos-14] steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Setup | Rust - uses: dtolnay/rust-toolchain@29eef336d9b2848a0b548edc03f92a220660cdb8 # stable + shell: bash + run: rustup toolchain install # channel + components from rust-toolchain.toml - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6 with: @@ -63,6 +64,18 @@ jobs: timeout_seconds: 54000 # 15 hours :/ command: uvx cibuildwheel --output-dir wheelhouse api/py + - name: Install syft + uses: anchore/sbom-action/download-syft@e22c389904149dbc22b58101806040fa8d37a610 # v0.24.0 + + - name: Inject CycloneDX + SPDX SBOMs into wheels (PEP 770) + shell: bash + run: | + set -ex + uv pip install --system wheel + for w in wheelhouse/*.whl ; do + python .github/scripts/inject_wheel_sboms.py 
"$w" + done + - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: wheels-${{github.run_id}}-${{matrix.os}} @@ -72,7 +85,7 @@ jobs: name: Make SDist runs-on: ubuntu-latest steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index 00c57a32f0..86a196ebea 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -25,7 +25,7 @@ jobs: runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - uses: nick-fields/retry@ad984534de44a9489a53aefd81eb77f87c70dc60 # v4 diff --git a/.gitignore b/.gitignore index eca6887263..226275610d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,6 @@ target **/*.rs.bk *.rustfmt *.back -Cargo.lock examples/data .idea .cached/** diff --git a/.travis/cargo-deny-check.sh b/.travis/cargo-deny-check.sh index 6f6e082bb7..a1e61f446a 100755 --- a/.travis/cargo-deny-check.sh +++ b/.travis/cargo-deny-check.sh @@ -7,4 +7,7 @@ else CARGO_DENY="cargo deny" fi -(cd api/rs ; $CARGO_DENY check) +set -e + +(cd api/rs ; $CARGO_DENY check -c deny.toml) +(cd cli ; $CARGO_DENY check -c deny.toml) diff --git a/.travis/ci-system-setup.sh b/.travis/ci-system-setup.sh index 0aa6d2cf92..4d68e2f4e7 100755 --- a/.travis/ci-system-setup.sh +++ b/.travis/ci-system-setup.sh @@ -3,11 +3,6 @@ set -e [ -d $ROOT/.travis ] || exit 1 "\$ROOT not set correctly '$ROOT'" -if [ -z "$RUSTUP_TOOLCHAIN" ] -then - export RUSTUP_TOOLCHAIN=1.91.0 -fi - export RUSTUP_TOOLCHAIN PATH=$PATH:$HOME/.cargo/bin diff --git a/.travis/native.sh b/.travis/native.sh index 56748b17c5..9cc0ec3fd0 100755 --- a/.travis/native.sh +++ b/.travis/native.sh @@ -2,11 +2,6 @@ set -ex 
-if [ -z "$RUSTUP_TOOLCHAIN" ] -then - export RUSTUP_TOOLCHAIN=1.91.0 -fi - rustup update cargo update diff --git a/.travis/tf.sh b/.travis/tf.sh deleted file mode 100755 index 2d20db30e6..0000000000 --- a/.travis/tf.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/sh - -set -ex - -if [ -z "$CACHEDIR" ] -then - CACHEDIR=`dirname $0`/../.cached -fi - - -(cd tensorflow; cargo test --release --features conform) diff --git a/AGENTS.md b/AGENTS.md index 4692ec13fb..57f2f2d30d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -58,8 +58,8 @@ cargo test -p tract-core # test the whole workspace cargo test --workspace -# format (always repo-wide, never per-crate; pin the toolchain to avoid spurious diffs) -cargo +1.91.0 fmt --all +# format (always repo-wide, never per-crate; rust-toolchain.toml pins stable, so bare cargo fmt is correct) +cargo fmt --all # lint cargo clippy --workspace @@ -263,13 +263,17 @@ Re-export shims in `transformers/src/ops/mod.rs` keep downstream crates - Comments describe the **current** code only. Don't narrate the diff ("the previous code did X", "this used to be Y", "was a copy-paste of the 32x1 kernel") -- that history belongs in the commit message and will be wrong after the next refactor. - Avoid section banners (// -- Step 2: Pad -> Reshape --), prefer split in functions. It's ok to have long function prototype in private function (within reason) #[allow(clippy::too_many_arguments)] authorized in such case. + Doc comments (`///` / `//!`): + - Unlike inline comments, these are encouraged. tract has historically under-documented its items -- do add a concise doc comment on public / non-trivial items (ops, declutter & codegen passes, public fns) stating what it is, its contract, valid inputs, and which rules it interacts with. + - Same anti-narration rule as inline: document the **current** contract, not benchmarks, perf numbers ("0.77 ms vs ORT's 1.13 ms"), issue numbers, or change history ("Measured on...", "Regression:..."). 
+ Idioms: - Prefer `as_X()` over `to_X().ok()` for cheap reference-style conversions. - No new `unsafe` without explicit permission. `shunt_outside_unchecked` is a last resort for surgical patches whose safety is locally obvious; reach for safe alternatives first. - Don't add abstraction beyond the task. Three similar lines beat a premature helper. Formatting: - - Always run `cargo +1.91.0 fmt --all` before committing -- bare `cargo fmt` uses a newer rustfmt and produces spurious diffs CI rejects. + - Always run `cargo fmt --all` before committing. The repo's `rust-toolchain.toml` pins the stable channel, so bare `cargo fmt` uses the same rustfmt CI checks against -- don't override the toolchain. PR comments and review replies: - Open PR with a short, crystal-clear summary paragraph -- one or two sentences stating what the PR is about and why it matters, before details if necessary. diff --git a/CHANGELOG.md b/CHANGELOG.md index 85f47820ca..ecbb0bb388 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,22 +1,48 @@ -# 0.23.0 — soon + +# 0.22.x → 0.23 in a nutshell + +- **`tract` facade is now the recommended public API.** Renamed from `tract-rs`; sole crate under semver, one curated surface (`Model`/`Runnable`/`State`/`Tensor`/`TDim` + `nnef()`/`onnx()`/`runtime_for_name`). `ndarray` removed from public types in favour of opt-in `impl_ndarray_interop!()`. Most other API renames in this release fall out of this consolidation (`Value`→`Tensor`, `concretize_symbols`→`set_symbols`, `default` runtime → `cpu`, …). +- **GPU is first-class.** `cuda` + `metal` `Runtime` impls with f16 conv + cuDNN, CUDA 13 (CUDA 12 dropped), automatic per-node CPU fallback when GPU rejects a shape; virtual `gpu` / `gpu-or-cpu` names for portable downstream code. 
+- **Supply-chain hardening.** `Cargo.lock` tracked; CDX + SPDX SBOMs on release binaries plus PEP 770 SBOMs in wheels (`cargo-auditable` + GitHub attestations). + +## Migrating from 0.22.x to 0.23 + +For normal usage we recommend adopting the **`tract` facade crate** (the public API at `api/rs`) instead of wiring `tract-core`, `tract-nnef`, `tract-onnx`, `tract-pulse`, `tract-cuda`, `tract-metal`, etc. directly. The facade exposes one stable surface — `nnef()`, `onnx()`, `runtime_for_name("cpu" | "gpu" | "gpu-or-cpu" | "cuda" | "metal" | ...)`, plus `Model`, `Runnable`, `State`, `Tensor`, `TDim`, and a `SetSymbols` transform builder — with all the backends curated behind it. `impl_ndarray_interop!()` (0.23.0-dev.5) keeps `ndarray` interop opt-in without leaking an `ndarray` version into the public API. Downstream code that pinned `tract-core` + `tract-onnx` directly can usually drop those deps in favour of `tract = "0.23"` and `use tract::prelude::*;`. Examples are now organised around this facade — see `examples/onnx-mobilenet-v2`, `examples/nnef-mobilenet-v2`, and `examples/causal_llm`. + +# 0.23.0 - 2026-05-1 + +This section lists changes since 0.23.0-dev.5 only; the dev.2…dev.5 sections below cover the rest of the 0.22.x→0.23.0 delta. + +### API — breaking + +- **`concretize_symbols` / `substitute_symbols` renamed to `set_symbols`.** Affects `TypedModel::set_symbols`, the `SetSymbols` transform (was `ConcretizeSymbols`), and the `--transform set_symbols=...` CLI form. No deprecation aliases — call sites must be updated. +- **`default` runtime renamed to `cpu`.** `runtime_for_name("default")` still resolves to the CPU runtime (ad-hoc alias), but `Runtime::name()` returns `"cpu"`. JSON loading configs and `--loading-config-path` payloads that pin `"default"` keep working. +- **`nnef().with_tract_core()` removed.** The `tract_core` extension is opt-out since 0.23.0-pre — call `disable_tract_core()` instead, or just drop the `with_tract_core()?` line.
+ ### CPU / linalg - **ARM SME backend (ARMv9.2-A).** New `linalg/arm64/sme` module provides SME GEMM (Phase 1) and SME2 GEMV (`sme_mmv_f32_64x1`, Phase 2A) micro-kernels for Apple M4+, Cortex-X4, and other ARMv9.2-A+ chips. Dispatch is gated on a 512-bit streaming vector length at runtime; SME2 assembler detection skips kernels when the assembler lacks support. Force via `TRACT_CPU_AARCH64_KIND=applem` (or `generic` to disable). +- **ARMv9 SVE**: f32/f16 GEMM+GEMV and int8→i32 GEMM+GEMV kernels - **Apple AMX: shape-aware dispatch.** AMX kernel selection is now M/N/K-aware at runtime, yielding 5–43% wins across canary models. - **NEON element-wise kernels.** `HardSwish`, `SiLU`, and `GELU` get dedicated aarch64 NEON kernels wired as single graph ops. - **x86_64.** M-aware kernel picker; AVX-512 GEMM routed to the 16×8 / 32×5 / 32×6 kernels. - **WASM SIMD.** Relaxed-SIMD FMA in all MMM kernels; 32×1 GEMV kernel (8 v128 accumulators, 8-way ILP); vectorised sigmoid and tanh; `rustfft wasm_simd` enabled. Low-accumulator MMM paths recover 8–23% under `+relaxed-simd`; M-band GEMV dispatch woken up (30–37% on small-M). `Executor::RayonGlobal` for `wasm-bindgen-rayon`. +- **WASM**: relaxed-dot int8 fast path + PackedI8K4 + SIMD int8 matmul. - **Multithreaded GEMM.** TLS borrow and sync hoisted out of the per-tile inner loop; 2D chunked dispatch with a small-MMM threshold avoids rayon overhead on small operands. - **im2col.** Contiguous-x fast path for valid (zero-padding, unit stride) convolutions; grouped lazy im2col extended; depthwise convolutions excluded from lazy im2col path; N=1/2/3 zone dispatch for depthwise. - **General.** Same-shape fast path in `BinMiniOp::generic_eval`; `rbytes=96/128` fast paths for mn-major packing. - **EinSum** Fold contiguous same-role axes in standard codegen. - **BLAS / SGemm integration dropped.** +- **Cache-adaptive 2D-blocking** for the single-thread MMM tile walk; per-OS L2-size detection on Linux. 
### ONNX - `Resize`: `pytorch_half_pixel` coordinate transformer. - `Reshape` with 0-dims and rank change fixed (issue #2104). +- Support for GroupQueryAttention, MultiHeadAttention, MatMulNBits (4-bit), SkipLayerNormalization, SimplifiedLayerNormalization, BiasGelu / FastGelu / + QuickGelu, LpNormalization, MeanVarianceNormalization, GroupNormalization, RotaryEmbedding, opset-24 Attention, Swish, Mish, Gelu, RMSNormalization. +- LayerNorm: fixed output dtype mismatch with F16 inputs. ### NNEF @@ -26,16 +52,20 @@ ### Pulse for chunked attention layers (experimental) - **`pulse::Blockify` rewrite pass.** Translates block-diagonal multi-time-axis subgraphs into chunk-parallel form: recognises quadratic sections (EinSum terminators, `DiagGather` initiators, banded masks, Softmax body chains) and rewrites each into a per-chunk section. Covered by ex01–ex10 synthetic harness cases. -- **`DiagGather` op** (moved from `transformers` into `core`): causal skew-trick gather with ROI-driven narrowing and re-anchoring. +- **`DiagGather` op**: causal skew-trick gather with ROI-driven narrowing and re-anchoring. - **`WindowOnAxis` op**: windowed gather over the streaming axis with configurable pad value. - **`AxisOp::Reshape` pulsifier**: auto-inserts alignment `Delay` on streaming-axis size change. - Stream-axis LCM merge + slope-based per-pulse sizing for `Range`. +- **Scan body state reused across iterations** instead of reallocated per step. +- **`scaled_masked_softmax`** gains a `bool`-mask variant and a `post_softmax_mask` variant on both cuda and metal. +- **`GpuPulsePad`**: stride-aware initial copy fixes 26% drift on pulsified encoders where a fused move axis fed a non-contiguous view to the pad. ### Runtime / plan - Per-node shape resolve skipped once all symbols are bound. - `TDim::Sym` fast-path in shape resolve; lock-free `guess_scenario` on empty scope. - `PropagateRoi` iterates to fixed point and simplifies. 
+- **Virtual runtime names `gpu` and `gpu-or-cpu`.** `runtime_for_name("gpu")` returns the first available GPU backend (cuda or metal) or errors; `"gpu-or-cpu"` falls back to CPU if none is present. ### Scan @@ -53,7 +83,17 @@ - `doc/symbolic-shapes.md`: TDim, Symbol, and how to bind them. - `doc/op.md`: working with a `Tensor`'s data. - `doc/cli-recipe.md`: `--audit-json`, `--save-outputs`, timing pitfalls, environment-variable table. -- `README.md`: refreshed — current backends, modern examples table, Python bindings section, torch-to-nnef pointer. +- `README.md`: refreshed — current runtimes, modern examples table, Python bindings section, torch-to-nnef pointer. + +### Supply chain / build / CI + +- **`Cargo.lock` tracked.** All workspace + binary builds are now reproducible against the same dependency snapshot. +- **SBOMs on release binaries (CycloneDX + SPDX).** `tract-cli` release binaries are built with `cargo-auditable` (Rust dep tree embedded in the `.dep-v0` section) and shipped alongside CDX + SPDX SBOMs generated by `syft`. Both SBOMs are signed via GitHub attestations (`actions/attest-sbom` + `actions/attest-build-provenance`). +- **PEP 770 SBOMs in Python wheels.** Wheels are built with `cargo-auditable` and have `sbom.cdx.json` + `sbom.spdx.json` injected into `.dist-info/sboms/` per PEP 770. +- **Release builds pinned to current stable rustc** via `dtolnay/rust-toolchain@stable`. +- **`cargo-deny` lints wired up for `tract-cli`.** +- **zizmor SARIF upload** to GitHub's security tab. + # 0.23.0-dev.5 - 2026-04-22 @@ -146,7 +186,7 @@ - **`into_tract()` renamed to `into_model()`** in all API layers. - **`DatumType` variant names shortened** — the `TRACT_DATUM_TYPE_` prefix is dropped (C API). - **Deprecated state methods removed**: `init_states()`, `state_initializers`, and the `n_states` parameter are gone from `State` trait and `RunTensors`. 
-- **Python**: `concretize_symbols` and `pulse` methods replaced by typed transform classes; `TransformSpec` is now an abstract base class. +- **Python**: `set_symbols` and `pulse` methods replaced by typed transform classes; `TransformSpec` is now an abstract base class. ### Improvements diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000..bb222f4f59 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,59 @@ +# CLAUDE.md — tract contributor rules (read fully; this is auto-loaded) + +tract is Sonos' Rust NN inference engine. Full guide: AGENTS.md. Architecture/ +reference: doc/. This file is the rules an agent must follow to contribute cleanly. + +## Before you commit +- Format with stable rustfmt: `cargo fmt --all`. The repo's `rust-toolchain.toml` + pins the stable channel, so bare `cargo fmt` picks the same rustfmt CI checks + against — don't override the toolchain. Metal files too, on Linux. +- `cargo clippy --workspace` clean. + +## Commit messages +- One short paragraph: what was wrong + the fix. Nothing else. +- No consequence chains ("X broke Y broke Z"), no "Result:/Symptom:" sections, + no bullet lists of every place the bug surfaced. + +## Inline comments +- Default to NONE. Names carry the meaning. A comment signals a hidden + constraint / invariant / workaround — not narration. +- Never describe the diff or history ("used to be X", "previously…"). Comments + describe current code only. +- No section-banner comments; split into functions instead. + +## Doc comments (`///` / `//!`) +- DO add a concise one on public / non-trivial items — ops, declutter & codegen + passes, public fns. State what it is, its contract, valid inputs, and which + rules it interacts with. This is the one place to be more generous than before. +- Same anti-narration rule: document the *current contract*, not benchmarks, + perf numbers, issue numbers, or history ("Measured on…", "Regression:…"). 
+ +## How to change a model +- Use `TypedModelPatch` / `Rewriter` / `ModelTransform`. Do NOT hand-roll + model-walk loops or rebuild a fresh TypedModel. +- Don't touch `pulse` / `pulse-opl` casually — subtle streaming invariants. + +## Inspecting a model +- To inspect the op graph programmatically, use `tract [--cuda|--metal] + dump --audit-json` (JSON node list to stdout) rather than scraping the colored + `dump` output. Handy for checking which ops landed on which backend. + +## Public API +- The public surface is `api/rs/src/lib.rs`. Check there, not internal `pub` + items. Apps/examples/bindings use `api/rs` only. + +## Tests +- Add op tests to the `suite-*` crates; add synthetic NNEF cases under + `harness/nnef-test-cases/` (driven by `runme.sh` + `--assert-output-bundle`). + If the CLI can't express the assertion, extend the CLI. +- No new Rust integration tests for the above; no mocking internals — prefer + real model round-trips. + +## Idioms / avoid +- No new `unsafe` outside linalg kernels without explicit permission. + No abstraction beyond the task — three similar lines beat a premature helper. + +## Pull requests +- Open with a 1–2 sentence summary of what and why. +- Follow-up questions/review replies are handled by a HUMAN, not the bot. The + maintainer wants to talk to the author, not prompt an LLM. diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000000..4b694f50a2 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,6001 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if 1.0.4", + "getrandom 0.3.4", + "once_cell", + "serde", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "aligned" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4508988c62edf04abd8d92897fca0c2995d907ce1dfeaf369dac3716a40685" +dependencies = [ + "as-slice", +] + +[[package]] +name = "aligned-vec" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc890384c8602f339876ded803c97ad529f3842aba97f6392b3dba0dd171769b" +dependencies = [ + "equator", +] + +[[package]] +name = "alloca" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4" +dependencies = [ + "cc", +] + +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + +[[package]] +name = "alsa" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812947049edcd670a82cd5c73c3661d2e58468577ba8489de58e1a73c04cbd5d" +dependencies = [ + "alsa-sys", + "bitflags 2.11.1", + "cfg-if 1.0.4", + "libc", +] + +[[package]] +name = "alsa-sys" +version = "0.4.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad7569085a265dd3f607ebecce7458eaab2132a84393534c95b18dcbc3f31e04" +dependencies = [ + "libc", + "pkg-config", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstream" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" + +[[package]] +name = "anstyle-parse" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.61.2", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "anymap3" +version = "1.0.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "170433209e817da6aae2c51aa0dd443009a613425dd041ebfb2492d1c4c11a25" + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1" +dependencies = [ + "derive_arbitrary", +] + +[[package]] +name = "arg_enum_proc_macro" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "as-slice" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "516b6b4f0e40d50dcda9365d53964ec74560ad4284da2e7fc97122cd83174516" +dependencies = [ + "stable_deref_trait", +] + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "autocfg" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" + +[[package]] +name = "av-scenechange" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f321d77c20e19b92c39e7471cf986812cbb46659d2af674adc4331ef3f18394" +dependencies = [ + "aligned", + "anyhow", + "arg_enum_proc_macro", + 
"arrayvec", + "log", + "num-rational", + "num-traits", + "pastey 0.1.1", + "rayon", + "thiserror 2.0.18", + "v_frame", + "y4m", +] + +[[package]] +name = "av1-grain" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cfddb07216410377231960af4fcab838eaa12e013417781b78bd95ee22077f8" +dependencies = [ + "anyhow", + "arrayvec", + "log", + "nom 8.0.0", + "num-rational", + "v_frame", +] + +[[package]] +name = "avif-serialize" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7178fe5f7d460b13895ebb9dcb28a3a6216d2df2574a0806cb51b555d297f38" +dependencies = [ + "arrayvec", +] + +[[package]] +name = "axum" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31b698c5f9a010f6573133b09e0de5408834d0c82f8d7475a89fc1867a71cd90" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-macros" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aa268c23bfbbd2c4363b9cd302a4f504fb2a9dfe7e3451d66f35dd392e20aca" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "base64" +version = "0.13.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bindgen" +version = "0.65.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfdf7b466f9a4903edc73f95d6d2bcd5baf8ae620638762244d3f60143643cc5" +dependencies = [ + "bitflags 1.3.2", + "cexpr", + "clang-sys", + "lazy_static", + "lazycell", + "log", + "peeking_take_while", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn", + "which", +] + +[[package]] +name = "bindgen" +version = "0.72.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" +dependencies = [ + "bitflags 2.11.1", + "cexpr", + "clang-sys", + "itertools 0.13.0", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 2.1.2", + "shlex", + "syn", +] + +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec 0.8.0", +] + +[[package]] +name = "bit-set" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2f926cc3060f09db9ebc5b52823d85268d24bb917e472c0c4bea35780a7d" +dependencies = [ + "bit-vec 0.9.1", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + +[[package]] +name = "bit-vec" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b71798fca2c1fe1086445a7258a4bc81e6e49dcd24c8d0dd9a1e57395b603f51" +dependencies = [ + "serde", +] + +[[package]] +name = "bit_field" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e4b40c7323adcfc0a41c4b88143ed58346ff65a288fc144329c5c45e05d70c6" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" +dependencies = [ + "serde_core", +] + +[[package]] +name = "bitstream-io" +version = "4.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7eff00be299a18769011411c9def0d827e8f2d7bf0c3dbf53633147a8867fd1f" +dependencies = [ + "no_std_io2", +] + +[[package]] +name = "block" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "block2" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdeb9d870516001442e364c5220d3574d2da8dc765554b4a617230d33fa58ef5" +dependencies = [ + "objc2", +] + +[[package]] +name = "boow" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4505c91a2ef58b8ff2b8c579cd580e4f94949828b4f9a888666cecf08d4124" +dependencies = [ + "cfg-if 0.1.10", +] + +[[package]] +name = "box_drawing" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "ea27d8d5fd867b17523bf6788b1175fa9867f34669d057e9adaf76e27bcea44b" + +[[package]] +name = "built" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c0e531d93d39c34eef561e929e8a7f86d77a5af08aac4f6d6e39976c51858e9" + +[[package]] +name = "bumpalo" +version = "3.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649" + +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "byteorder-lite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "castaway" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" +dependencies = [ + "rustversion", +] + +[[package]] +name = "causal_llm" +version = "0.1.0" +dependencies = [ + "anyhow", + "axum", + "axum-macros", + "clap", + "env_logger", + "float-ord", + "log", + "ndarray", + "ndarray-npy", + "rand 0.10.1", + "reqwest", + "serde", + "serde_json", + "tokenizers", + "tokio", + "tokio-scoped", + 
"tract", +] + +[[package]] +name = "cc" +version = "1.2.62" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" +dependencies = [ + "find-msvc-tools", + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom 7.1.3", +] + +[[package]] +name = "cfg-if" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "chacha20" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" +dependencies = [ + "cfg-if 1.0.4", + "cpufeatures 0.3.0", + "rand_core 0.10.1", +] + +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading 0.8.9", +] + +[[package]] +name = "clap" +version = "4.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ddb117e43bbf7dacf0a4190fef4d345b9bad68dfc649cb349e7d17d28428e51" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2ce8604710f6733aa641a2b3731eaa1e8b3d9973d5e3565da11800813f997a9" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" + +[[package]] +name = "color_quant" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" + +[[package]] +name = "colorchoice" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" + +[[package]] +name = "colorous" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4e18bf7a165bf7028fde98609a0f1e8f7498d762a212598e6c891f6893556ec" + +[[package]] +name = 
"combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + +[[package]] +name = "compact_str" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dfdd1c2274d9aa354115b09dc9a901d6c5576818cdf70d14cae2bdb47df00ab" +dependencies = [ + "castaway", + "cfg-if 1.0.4", + "itoa", + "rustversion", + "ryu", + "serde", + "static_assertions", +] + +[[package]] +name = "console" +version = "0.16.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d64e8af5551369d19cf50138de61f1c42074ab970f74e99be916646777f8fc87" +dependencies = [ + "encode_unicode", + "libc", + "unicode-width", + "windows-sys 0.61.2", +] + +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "core-graphics-types" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb" +dependencies = [ + "bitflags 2.11.1", + "core-foundation", + "libc", +] + +[[package]] +name = "core-proptest-pulse" +version = "0.20.7-pre" +dependencies = [ + "env_logger", + "log", + "proptest", + "tract-core", + "tract-nnef", + "tract-pulse", +] + +[[package]] +name = "core_affinity" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a034b3a7b624016c6e13f5df875747cc25f884156aad2abd12b6c46797971342" +dependencies = [ + "libc", + 
"num_cpus", + "winapi", +] + +[[package]] +name = "coreaudio-rs" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d5d7dca3ebcf65a035582c9ad4385371a9d9ee6537474d2a278f4e1e475bb58" +dependencies = [ + "bitflags 2.11.1", + "libc", + "objc2-audio-toolbox", + "objc2-core-audio", + "objc2-core-audio-types", + "objc2-core-foundation", +] + +[[package]] +name = "cpal" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8942da362c0f0d895d7cac616263f2f9424edc5687364dfd1d25ef7eba506d7" +dependencies = [ + "alsa", + "coreaudio-rs", + "dasp_sample", + "jni 0.21.1", + "js-sys", + "libc", + "mach2", + "ndk", + "ndk-context", + "num-derive", + "num-traits", + "objc2", + "objc2-audio-toolbox", + "objc2-avf-audio", + "objc2-core-audio", + "objc2-core-audio-types", + "objc2-core-foundation", + "objc2-foundation", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "windows", +] + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "cpufeatures" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if 1.0.4", +] + +[[package]] +name = "criterion" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "950046b2aa2492f9a536f5f4f9a3de7b9e2476e575e05bd6c333371add4d98f3" +dependencies = [ + "alloca", + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "itertools 0.13.0", + 
"num-traits", + "oorandom", + "page_size", + "plotters", + "rayon", + "regex", + "serde", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8d80a2f4f5b554395e47b5d8305bc3d27813bacb73493eb1001e8f76dae29ea" +dependencies = [ + "cast", + "itertools 0.13.0", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "cudarc" +version = "0.19.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cea5f10a99e025c1b44ae2354c2d8326b25ddbd0baf76bde8e55cfd4018a2cc" +dependencies = [ + "half", + "libloading 0.9.0", +] + +[[package]] +name = "curl" +version = "0.4.49" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"79fc3b6dd0b87ba36e565715bf9a2ced221311db47bd18011676f24a6066edbc" +dependencies = [ + "curl-sys", + "libc", + "openssl-probe 0.1.6", + "openssl-sys", + "schannel", + "socket2", + "windows-sys 0.59.0", +] + +[[package]] +name = "curl-sys" +version = "0.4.88+curl-8.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "644816de6547255eff4e491a1dda1c19b7237f00b62a61e6e64859ce4f2906d0" +dependencies = [ + "cc", + "libc", + "libz-sys", + "openssl-sys", + "pkg-config", + "vcpkg", + "windows-sys 0.61.2", +] + +[[package]] +name = "daachorse" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f55d7153ba3b507595872a3874803f07a8a81d1e888abed8e5db7da0597d6e2" + +[[package]] +name = "darling" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core", + "quote", + "syn", +] + +[[package]] +name = "dary_heap" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b1e3a325bc115f096c8b77bbf027a7c2592230e70be2d985be950d3d5e60ebe" +dependencies = [ + "serde", +] + +[[package]] +name = "dasp_sample" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c87e182de0887fd5361989c677c4e8f5000cd9491d6d563161a8f3a5519fc7f" + 
+[[package]] +name = "derive-new" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "derive_arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e567bd82dcff979e4b03460c307b3cdc9e96fde3d73bed1496d2bc75d9dd62a" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "dinghy-test" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346629f4ca872d211748429b0f84f55a5540f25eae129b9eef5bb6ad750e56af" + +[[package]] +name = "dirs" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e" +dependencies = [ + 
"dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.61.2", +] + +[[package]] +name = "dispatch2" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e0e367e4e7da84520dedcac1901e4da967309406d1e51017ae1abfb97adbd38" +dependencies = [ + "bitflags 2.11.1", + "objc2", +] + +[[package]] +name = "displaydoc" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ac70aa55017e108007fbaf5aa0f54b021c98f92ff8af59d42eda9da96e3dd4f" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "downcast-rs" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "117240f60069e65410b3ae1bb213295bd828f707b5bec6596a1afc8793ce0cbc" + +[[package]] +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + +[[package]] +name = "dyn-eq" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c2d035d21af5cde1a6f5c7b444a5bf963520a9f142e5d06931178433d7d5388" + +[[package]] +name = "dyn-hash" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fdab65db9274e0168143841eb8f864a0a21f8b1b8d2ba6812bbe6024346e99e" + +[[package]] +name = "either" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" + +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + +[[package]] +name = "env_filter" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "jiff", + "log", +] + +[[package]] +name = "equator" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4711b213838dfee0117e3be6ac926007d7f433d7bbe33595975d4190cb07e6fc" +dependencies = [ + "equator-macro", +] + +[[package]] +name = "equator-macro" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "erased-serde" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2add8a07dd6a8d93ff627029c51de145e12686fbc36ecb298ac22e74cf02dec" +dependencies = [ + "serde", + "serde_core", + "typeid", +] + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" +dependencies = [ + "cc", +] + +[[package]] +name = "example-dump-nnef-mobilenet-v2" +version = "0.20.7-pre" +dependencies = [ + "anyhow", + "image", + "ndarray", + "tract", +] + +[[package]] +name = "example-nnef-mobilenet-v2" +version = "0.20.7-pre" +dependencies = [ + "anyhow", + "image", + "ndarray", + "tract", +] + +[[package]] +name = "example-onnx-mobilenet-v2" +version = "0.20.7-pre" +dependencies = [ + "anyhow", + "image", + "ndarray", + "tract", +] + +[[package]] +name = "example-pytorch-resnet" +version = "0.20.7-pre" +dependencies = [ + "anyhow", + "image", + "ndarray", + "tract", +] + +[[package]] +name = "example-tensorflow-mobilenet-v2" +version = "0.20.7-pre" +dependencies = [ + "image", + "tract-tensorflow", +] + +[[package]] +name = "example-tflite-mobilenet-v3" +version = "0.20.7-pre" +dependencies = [ + "image", + "tract-tflite", +] + +[[package]] +name = "exr" +version = "1.74.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4300e043a56aa2cb633c01af81ca8f699a321879a7854d3896a0ba89056363be" +dependencies = [ + "bit_field", + "half", + "lebe", + "miniz_oxide", + "rayon-core", + "smallvec", + "zune-inflate", +] + +[[package]] +name = "face_detection_yolov8onnx_example" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "image", + "ndarray", + "tract", +] + +[[package]] +name = "face_similarity_arcface_onnx" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "image", + "ndarray", + "tract", +] + +[[package]] +name = "fancy-regex" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72cf461f865c862bb7dc573f643dd6a2b6842f7c30b07882b56bd148cc2761b8" +dependencies = [ + "bit-set 0.8.0", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "fastrand" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" + +[[package]] +name = "fax" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caf1079563223d5d59d83c85886a56e586cfd5c1a26292e971a0fa266531ac5a" + +[[package]] +name = "fdeflate" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" +dependencies = [ + "simd-adler32", +] + +[[package]] +name = "filetime" +version = "0.2.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c287a33c7f0a620c38e641e7f60827713987b3c0f26e8ddc9462cc69cf75759" +dependencies = [ + "cfg-if 1.0.4", + "libc", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "flatbuffers" +version = "25.12.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35f6839d7b3b98adde531effaf34f0c2badc6f4735d26fe74709d8e513a96ef3" +dependencies = [ + "bitflags 2.11.1", + "rustc_version", +] + +[[package]] +name = "flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", + "zlib-rs", +] + +[[package]] +name = "float-ord" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + +[[package]] +name = "foreign-types" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" +dependencies = [ + "foreign-types-macros", + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-macros" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "foreign-types-shared" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "fs-err" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73fde052dbfc920003cfd2c8e2c6e6d4cc7c1091538c3a24226cec0665ab08c0" +dependencies = [ + "autocfg", +] + +[[package]] +name = "fs2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + +[[package]] +name = "futures-channel" +version = "0.3.32" +source 
= "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-io" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" + +[[package]] +name = "futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "slab", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if 1.0.4", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if 1.0.4", + "js-sys", + "libc", + "r-efi 5.3.0", + "wasip2", + "wasm-bindgen", +] + +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if 1.0.4", + "js-sys", + "libc", + "r-efi 6.0.0", + "rand_core 0.10.1", + "wasip2", + "wasip3", + "wasm-bindgen", +] + +[[package]] +name = "ggml" +version = "0.2.0-dev" +source = "git+https://github.com/rustformers/llm.git?rev=9376078#9376078c12ea1990bd42e63432656819a056d379" +dependencies = [ + "ggml-sys", + "memmap2 0.5.10", + "thiserror 1.0.69", +] + +[[package]] +name = "ggml-sys" +version = "0.2.0-dev" +source = "git+https://github.com/rustformers/llm.git?rev=9376078#9376078c12ea1990bd42e63432656819a056d379" +dependencies = [ + "cc", +] + +[[package]] +name = "gif" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee8cfcc411d9adbbaba82fb72661cc1bcca13e8bba98b364e62b2dba8f960159" +dependencies = [ + "color_quant", + "weezl", +] + +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if 1.0.4", + "crunchy", + "num-traits", + "rand 0.9.4", + "rand_distr 0.5.1", + "zerocopy", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash 0.1.5", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash 0.2.0", + "serde", + "serde_core", +] + +[[package]] +name = "hashbrown" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + +[[package]] +name = "home" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "hound" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" + +[[package]] +name = "http" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be7462df143984c4598a256ef469b251d7d7f9e271135073e78fc535414f3d0" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + 
"futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb92f162bf56536459fc83c79b974bb12837acfed43d6bc370a7916d0ae15ecc" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "tokio", + "tokio-rustls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + +[[package]] +name = "icu_collections" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2984d1cd16c883d7935b9e07e44071dca8d917fd52ecc02c04d5fa0b5a3f191c" +dependencies = [ + "displaydoc", + "potential_utf", + "utf8_iter", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version 
= "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c56e5ee99d6e3d33bd91c5d85458b6005a22140021cc324cea84dd0e72cff3b4" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da3be0ae77ea334f4da67c12f149704f19f81d1adf7c51cf482943e84a2bad38" + +[[package]] +name = "icu_properties" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bee3b67d0ea5c2cca5003417989af8996f8604e34fb9ddf96208a033901e70de" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2bbb201e0c04f7b4b3e14382af113e17ba4f63e2c9d2ee626b720cbce54a14" + +[[package]] +name = "icu_provider" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "ident_case" +version = "1.0.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb68373c0d6620ef8105e855e7745e18b0d00d3bdb07fb532e434244cdb9a714" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "image" +version = "0.25.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85ab80394333c02fe689eaf900ab500fbd0c2213da414687ebf995a65d5a6104" +dependencies = [ + "bytemuck", + "byteorder-lite", + "color_quant", + "exr", + "gif", + "image-webp", + "moxcms", + "num-traits", + "png", + "qoi", + "ravif", + "rayon", + "rgb", + "tiff", + "zune-core", + "zune-jpeg", +] + +[[package]] +name = "image-webp" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3" +dependencies = [ + "byteorder-lite", + "quick-error 2.0.1", +] + +[[package]] +name = "imgref" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40fac9d56ed6437b198fddba683305e8e2d651aa42647f00f5ae542e7f5c94a2" + +[[package]] +name = "indexmap" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" +dependencies = [ + "equivalent", + "hashbrown 0.17.1", + "serde", + "serde_core", +] + +[[package]] +name = "indicatif" +version = "0.18.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"25470f23803092da7d239834776d653104d551bc4d7eacaf31e6837854b8e9eb" +dependencies = [ + "console", + "portable-atomic", + "unicode-width", + "unit-prefix", + "web-time", +] + +[[package]] +name = "infra" +version = "0.1.0" +dependencies = [ + "anyhow", + "downcast-rs", + "dyn-clone", + "env_logger", + "itertools 0.14.0", + "lazy_static", + "proptest", + "tract-core", +] + +[[package]] +name = "interpolate_name" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "inventory" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4f0c30c76f2f4ccee3fe55a2435f691ca00c0e4bd87abe4f4a851b1d4dac39b" +dependencies = [ + "rustversion", +] + +[[package]] +name = "ipnet" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "jiff" +version = "0.2.27" +source 
= "registry+https://github.com/rust-lang/crates.io-index" +checksum = "392c70591e8749fe235ddaf513e6f58b26bce3dcc16524cecc8936f75afa161e" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde_core", +] + +[[package]] +name = "jiff-static" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b605b0c050d845fc355bb11eb3f9a8deddc218ea60c76e61aa1f2adfb2c96a" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if 1.0.4", + "combine", + "jni-sys 0.3.1", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efd9a482cf3a427f00d6b35f14332adc7902ce91efb778580e180ff90fa3498" +dependencies = [ + "cfg-if 1.0.4", + "combine", + "jni-macros", + "jni-sys 0.4.1", + "log", + "simd_cesu8", + "thiserror 2.0.18", + "walkdir", + "windows-link", +] + +[[package]] +name = "jni-macros" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00109accc170f0bdb141fed3e393c565b6f5e072365c3bd58f5b062591560a3" +dependencies = [ + "proc-macro2", + "quote", + "rustc_version", + "simd_cesu8", + "syn", +] + +[[package]] +name = "jni-sys" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258" +dependencies = [ + "jni-sys 0.4.1", +] + +[[package]] +name = "jni-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies = [ + "jni-sys-macros", +] + +[[package]] +name 
= "jni-sys-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn", +] + +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "142bc4740e452c1e57ade0cbc129f139c9093e354346f0872ef985f4f5cf5f11" +dependencies = [ + "cfg-if 1.0.4", + "futures-util", + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "kdam" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d847be338ef16a13f97637c062d97fb52ebe0ff3b77fa18456d5ed366317e4f7" +dependencies = [ + "terminal_size", + "windows-sys 0.61.2", +] + +[[package]] +name = "keras-tract-tf2" +version = "0.20.7-pre" +dependencies = [ + "anyhow", + "ndarray", + "ndarray-npy", + "tract", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "lebe" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a79a3332a6609480d7d0c9eab957bca6b455b91bb84e66d19f5ff66294b85b8" + +[[package]] +name = "libc" +version = 
"0.2.186" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" + +[[package]] +name = "libfuzzer-sys" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f12a681b7dd8ce12bff52488013ba614b869148d54dd79836ab85aafdd53f08d" +dependencies = [ + "arbitrary", + "cc", +] + +[[package]] +name = "libloading" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +dependencies = [ + "cfg-if 1.0.4", + "windows-link", +] + +[[package]] +name = "libloading" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60" +dependencies = [ + "cfg-if 1.0.4", + "windows-link", +] + +[[package]] +name = "libm" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + +[[package]] +name = "libredox" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f02ab6bace2054fb888a3c16f990117b579d14a3088e472d63c6011fa185c9d3" +dependencies = [ + "libc", +] + +[[package]] +name = "libz-sys" +version = "1.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc3a226e576f50782b3305c5ccf458698f92798987f551c6a02efe8276721e22" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + +[[package]] +name = "litemap" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "616ec5685824bcc94416c6d4a7a446eea774a31efd7062c8480ba6fd06d7a6e5" + +[[package]] +name = "loop9" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fae87c125b03c1d2c0150c90365d7d6bcc53fb73a9acaef207d2d065860f062" +dependencies = [ + "imgref", +] + +[[package]] +name = "mach2" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a1b95cd5421ec55b445b5ae102f5ea0e768de1f82bd3001e11f426c269c3aea" +dependencies = [ + "libc", +] + +[[package]] +name = "macro_rules_attribute" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65049d7923698040cd0b1ddcced9b0eb14dd22c5f86ae59c3740eab64a676520" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" + +[[package]] +name = "malloc_buf" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" +dependencies = [ + "libc", +] + +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "maybe-rayon" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519" +dependencies = [ + "cfg-if 1.0.4", + "rayon", +] + +[[package]] +name = "memchr" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b947ae49db0d222b1dbc6b113ce7248a3fc3a6ca21b696717bfc000ba4484d8" + +[[package]] +name = "memmap2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83faa42c0a078c393f6b29d5db232d8be22776a891f8f56e5284faee4a20b327" +dependencies = [ + "libc", +] + +[[package]] +name = "memmap2" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "714098028fe011992e1c3962653c96b2d578c4b4bce9036e15ff220319b1e0e3" +dependencies = [ + "libc", +] + +[[package]] +name = "memo-map" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" + +[[package]] +name = "metal" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7047791b5bc903b8cd963014b355f71dc9864a9a0b727057676c1dcae5cbc15" +dependencies = [ + "bitflags 2.11.1", + "block", + "core-graphics-types", + "foreign-types", + "log", + "objc", + "paste", +] + +[[package]] +name = "mime" +version = 
"0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "minijinja" +version = "2.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2929e494b2280e1e18959bb2e121da03347ae896896fdfaceaab43c88a02803f" +dependencies = [ + "memo-map", + "serde", +] + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "monostate" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3341a273f6c9d5bef1908f17b7267bbab0e95c9bf69a0d4dcf8e9e1b2c76ef67" +dependencies = [ + "monostate-impl", + "serde", + "serde_core", +] + +[[package]] +name = "monostate-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4db6d5580af57bf992f59068d4ea26fd518574ff48d7639b255a36f9de6e7e9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "moxcms" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb85c154ba489f01b25c0d36ae69a87e4a1c73a72631fc6c0eb6dde34a73e44b" +dependencies = [ + "num-traits", + "pxfm", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "ndarray-npy" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58e8a348bca0075000d999d750420d74434fd0d3e0993b456554f885e7657a11" +dependencies = [ + "byteorder", + "ndarray", + "num-complex", + "num-traits", + "py_literal", + "zip", +] + +[[package]] +name = "ndk" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3f42e7bbe13d351b6bead8286a43aac9534b82bd3cc43e47037f012ebfd62d4" +dependencies = [ + "bitflags 2.11.1", + "jni-sys 0.3.1", + "log", + "ndk-sys", + "num_enum", + "thiserror 1.0.69", +] + +[[package]] +name = "ndk-context" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" + +[[package]] +name = "ndk-sys" +version = "0.6.0+11769913" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee6cda3051665f1fb8d9e08fc35c96d5a244fb1be711a03b71118828afc9a873" +dependencies = [ + "jni-sys 0.3.1", +] + +[[package]] +name = "nemo-nemotron-asr" +version = "0.1.0" +dependencies = [ + "anyhow", + "float-ord", + "hound", + "itertools 0.14.0", + "ndarray", + "serde_json", + "tract", +] + +[[package]] +name = "nemo-nemotron-streaming-asr" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "cpal", + "float-ord", + "hound", + "itertools 0.14.0", + "ndarray", + "serde_json", + "tract", +] + +[[package]] +name = "nemo-parakeet-asr" +version = "0.1.0" +dependencies = [ + "anyhow", + "float-ord", + "hound", + "itertools 0.14.0", + "ndarray", + "serde_json", + "tract", +] + +[[package]] +name = "new_debug_unreachable" +version = 
"1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" + +[[package]] +name = "nnef-inceptionv3" +version = "0.20.7-pre" +dependencies = [ + "dinghy-test", + "env_logger", + "flate2", + "image", + "log", + "tract-core", + "tract-nnef", +] + +[[package]] +name = "no_std_io2" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418abd1b6d34fbf6cae440dc874771b0525a604428704c76e48b29a5e67b8003" +dependencies = [ + "memchr", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "nom" +version = "8.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" +dependencies = [ + "memchr", +] + +[[package]] +name = "nom-language" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2de2bc5b451bfedaef92c90b8939a8fff5770bdcc1fafd6239d086aab8fa6b29" +dependencies = [ + "nom 8.0.0", +] + +[[package]] +name = "noop_proc_macro" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + 
+[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "num_enum" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0bca838442ec211fa11de3a8b0e0e8f3a4522575b5c4c06ed722e005036f26" +dependencies = [ + "num_enum_derive", + "rustversion", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "680998035259dcfcafe653688bf2aa6d3e2dc05e98be6ab46afb089dc84f1df8" +dependencies = [ + "proc-macro-crate", + 
"proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "objc" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" +dependencies = [ + "malloc_buf", +] + +[[package]] +name = "objc2" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a12a8ed07aefc768292f076dc3ac8c48f3781c8f2d5851dd3d98950e8c5a89f" +dependencies = [ + "objc2-encode", +] + +[[package]] +name = "objc2-audio-toolbox" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6948501a91121d6399b79abaa33a8aa4ea7857fe019f341b8c23ad6e81b79b08" +dependencies = [ + "bitflags 2.11.1", + "libc", + "objc2", + "objc2-core-audio", + "objc2-core-audio-types", + "objc2-core-foundation", + "objc2-foundation", +] + +[[package]] +name = "objc2-avf-audio" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13a380031deed8e99db00065c45937da434ca987c034e13b87e4441f9e4090be" +dependencies = [ + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-core-audio" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1eebcea8b0dbff5f7c8504f3107c68fc061a3eb44932051c8cf8a68d969c3b2" +dependencies = [ + "dispatch2", + "objc2", + "objc2-core-audio-types", + "objc2-core-foundation", + "objc2-foundation", +] + +[[package]] +name = "objc2-core-audio-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a89f2ec274a0cf4a32642b2991e8b351a404d290da87bb6a9a9d8632490bd1c" +dependencies = [ + "bitflags 2.11.1", + "objc2", +] + +[[package]] +name = "objc2-core-foundation" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" +dependencies = [ + "bitflags 2.11.1", + "block2", + 
"dispatch2", + "libc", + "objc2", +] + +[[package]] +name = "objc2-encode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" + +[[package]] +name = "objc2-foundation" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272" +dependencies = [ + "bitflags 2.11.1", + "block2", + "libc", + "objc2", + "objc2-core-foundation", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "onig" +version = "6.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc3cbf698f9438986c11a880c90a6d04b9de27575afd28bbf45b154b6c709e2" +dependencies = [ + "bitflags 2.11.1", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e68317604e77e53b85896388e1a803c1d21b74c899ec9e5e1112db90735edd7" +dependencies = [ + "cc", + "pkg-config", +] + +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + +[[package]] +name = "openssl-sys" +version = "0.9.116" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28a22dc7140cda5f096e5e7724a6962ca81a7f8bfd2979f9b18c11af56318c4" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if 1.0.4", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pastey" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec" + +[[package]] +name = "pastey" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ee67f1008b1ba2321834326597b8e186293b049a023cdef258527550b9935b4" 
+ +[[package]] +name = "peeking_take_while" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pest" +version = "2.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0848c601009d37dfa3430c4666e147e49cdcf1b92ecd3e63657d8a5f19da662" +dependencies = [ + "memchr", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11f486f1ea21e6c10ed15d5a7c77165d0ee443402f0780849d1768e7d9d6fe77" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8040c4647b13b210a963c1ed407c1ff4fdfa01c31d6d2a098218702e6664f94f" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pest_meta" +version = "2.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89815c69d36021a140146f26659a81d6c2afa33d216d736dd4be5381a7362220" +dependencies = [ + "pest", + "sha2", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "pkg-config" +version = "0.3.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" + +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "png" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60769b8b31b2a9f263dae2776c37b1b28ae246943cf719eb6946a1db05128a61" +dependencies = [ + "bitflags 2.11.1", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide", +] + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "portable-atomic-util" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a106d1259c23fac8e543272398ae0e3c0b8d33c88ed73d0cc71b0f1d902618" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "potential_utf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" +dependencies = [ + "zerovec", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "primal-check" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0d895b311e3af9902528fbb8f928688abbd95872819320517cc24ca6b2bd08" +dependencies = [ + "num-integer", +] + +[[package]] +name = "proc-macro-crate" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" +dependencies = [ + "toml_edit", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "profiling" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d595e54a326bc53c1c197b32d295e14b169e3cfeaa8dc82b529f947fba6bcf5" +dependencies = [ + "profiling-procmacros", +] + +[[package]] +name = "profiling-procmacros" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4488a4a36b9a4ba6b9334a32a39971f77c1436ec82c38707bce707699cc3bbcb" +dependencies = [ + "quote", + "syn", +] + +[[package]] +name = "proptest" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b45fcc2344c680f5025fe57779faef368840d0bd1f42f216291f0dc4ace4744" +dependencies = [ + "bit-set 0.8.0", + "bit-vec 0.8.0", + "bitflags 2.11.1", + "num-traits", + "rand 0.9.4", + "rand_chacha", + "rand_xorshift", + "regex-syntax", + "rusty-fork", + "tempfile", + "unarray", +] + +[[package]] +name = "prost" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" +dependencies = [ + "bytes", + "prost-derive", +] + 
+[[package]] +name = "prost-derive" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" +dependencies = [ + "anyhow", + "itertools 0.14.0", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "prost-types" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" +dependencies = [ + "prost", +] + +[[package]] +name = "pxfm" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f" + +[[package]] +name = "py_literal" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "102df7a3d46db9d3891f178dcc826dc270a6746277a9ae6436f8d29fd490a8e1" +dependencies = [ + "num-bigint", + "num-complex", + "num-traits", + "pest", + "pest_derive", +] + +[[package]] +name = "pytorch-albert-v2" +version = "0.20.7-pre" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "clap_builder", + "clap_lex", + "ndarray", + "tokenizers", + "tract", +] + +[[package]] +name = "qoi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f6d64c71eb498fe9eae14ce4ec935c555749aef511cca85b5568910d6e48001" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + +[[package]] +name = "quick-error" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "rand" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" +dependencies = [ + "rand_chacha", + "rand_core 0.9.5", +] + +[[package]] +name = "rand" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207" +dependencies = [ + "chacha20", + "getrandom 0.4.2", + "rand_core 0.10.1", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + +[[package]] +name = "rand_core" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" + +[[package]] +name = "rand_distr" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" +dependencies = [ + "num-traits", + "rand 0.9.4", +] + +[[package]] +name = 
"rand_distr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d431c2703ccf129de4d45253c03f49ebb22b97d6ad79ee3ecfc7e3f4862c1d8" +dependencies = [ + "num-traits", + "rand 0.10.1", +] + +[[package]] +name = "rand_xorshift" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" +dependencies = [ + "rand_core 0.9.5", +] + +[[package]] +name = "rav1e" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43b6dd56e85d9483277cde964fd1bdb0428de4fec5ebba7540995639a21cb32b" +dependencies = [ + "aligned-vec", + "arbitrary", + "arg_enum_proc_macro", + "arrayvec", + "av-scenechange", + "av1-grain", + "bitstream-io", + "built", + "cfg-if 1.0.4", + "interpolate_name", + "itertools 0.14.0", + "libc", + "libfuzzer-sys", + "log", + "maybe-rayon", + "new_debug_unreachable", + "noop_proc_macro", + "num-derive", + "num-traits", + "paste", + "profiling", + "rand 0.9.4", + "rand_chacha", + "simd_helpers", + "thiserror 2.0.18", + "v_frame", + "wasm-bindgen", +] + +[[package]] +name = "ravif" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e52310197d971b0f5be7fe6b57530dcd27beb35c1b013f29d66c1ad73fbbcc45" +dependencies = [ + "avif-serialize", + "imgref", + "loop9", + "quick-error 2.0.1", + "rav1e", + "rayon", + "rgb", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-cond" +version = "0.4.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "2964d0cf57a3e7a06e8183d14a8b527195c706b7983549cd5462d5aa3747438f" +dependencies = [ + "either", + "itertools 0.14.0", + "rayon", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "readings-probe" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d44528f4328cceb716d2acfa443fcf847ba35feb107dd0e25a3421456127cd33" +dependencies = [ + "lazy_static", + "libc", + "num_cpus", + "thiserror 1.0.69", + "winapi", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags 2.11.1", +] + +[[package]] +name = "redox_users" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" +dependencies = [ + "getrandom 0.2.17", + "libredox", + "thiserror 2.0.18", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "reqwest" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "219c5811de6525e5416c7d5d53bb656d3afdbc6c5af816e0802bcfa42dbdc1c3" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "js-sys", + "log", + "percent-encoding", + "pin-project-lite", + "rustls", + "rustls-pki-types", + "rustls-platform-verifier", + "sync_wrapper", + "tokio", + "tokio-rustls", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "rgb" +version = "0.8.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b34b781b31e5d73e9fbc8689c70551fd1ade9a19e3e28cfec8580a79290cc4" + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if 1.0.4", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "ron" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4147b952f3f819eca0e99527022f7d6a8d05f111aeb0a62960c74eb283bec8fc" +dependencies = [ + "bitflags 2.11.1", + "once_cell", + "serde", + "serde_derive", + "typeid", + "unicode-ident", +] + +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + +[[package]] +name = "rustc-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" + +[[package]] +name = 
"rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "rustfft" +version = "6.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21db5f9893e91f41798c88680037dba611ca6674703c1a18601b01a72c8adb89" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", +] + +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags 2.11.1", + "errno", + "libc", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", +] + +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags 2.11.1", + "errno", + "libc", + "linux-raw-sys 0.12.1", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +dependencies = [ + "openssl-probe 0.2.1", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" 
+dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-platform-verifier" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d1e2536ce4f35f4846aa13bff16bd0ff40157cdb14cc056c7b14ba41233ba0" +dependencies = [ + "core-foundation", + "core-foundation-sys", + "jni 0.22.4", + "log", + "once_cell", + "rustls", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework", + "security-framework-sys", + "webpki-root-certs", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" + +[[package]] +name = "rustls-webpki" +version = "0.103.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "rusty-fork" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc6bf79ff24e648f6da1f8d1f011e9cac26491b619e6b9280f2b47f1774e6ee2" +dependencies = [ + "fnv", + "quick-error 1.2.3", + "tempfile", + "wait-timeout", +] + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "safetensors" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "675656c1eabb620b921efea4f9199f97fc86e36dd6ffd1fbbe48d0f59a4987f5" +dependencies = [ + "hashbrown 0.16.1", + "serde", + "serde_json", +] + 
+[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "scan_fmt" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b53b0a5db882a8e2fdaae0a43f7b39e7e9082389e978398bdf223a55b581248" +dependencies = [ + "regex", +] + +[[package]] +name = "schannel" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "security-framework" +version = "3.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" +dependencies = [ + "bitflags 2.11.1", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" 
+version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.150" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if 1.0.4", + "cpufeatures 0.2.17", + "digest", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "simd-adler32" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" + +[[package]] +name = "simd_cesu8" 
+version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94f90157bb87cddf702797c5dadfa0be7d266cdf49e22da2fcaa32eff75b2c33" +dependencies = [ + "rustc_version", + "simdutf8", +] + +[[package]] +name = "simd_helpers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95890f873bec569a0362c235787f3aca6e1e887302ba4840839bcc6459c42da6" +dependencies = [ + "quote", +] + +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom 7.1.3", + "serde", + "unicode-segmentation", +] + +[[package]] +name = "stable-diffusion" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "image", + "kdam", + "ndarray", + "ndarray-npy", + "rand 0.10.1", + "rand_distr 0.6.0", + "tokenizers", + "tract", +] + +[[package]] +name = "stable-diffusion-3" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "image", + "kdam", + "ndarray", + "rand 0.10.1", + "rand_distr 0.6.0", + 
"tokenizers", + "tract", +] + +[[package]] +name = "stable-diffusion-xl" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "image", + "kdam", + "ndarray", + "rand 0.10.1", + "rand_distr 0.6.0", + "tokenizers", + "tract", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + +[[package]] +name = "string-interner" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad3df9b59e2eded8d825c7c4363ad339a20fb6bc0b9a4778560f518f59910b15" +dependencies = [ + "hashbrown 0.16.1", + "serde", +] + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "suite-onnx" +version = "0.1.0" +dependencies = [ + "anyhow", + "bytes", + "env_logger", + "fs2", + "infra", + "itertools 0.14.0", + "lazy_static", + "log", + "prost", + "regex", + "tract-core", + "tract-hir", + "tract-onnx", +] + +[[package]] +name = "suite-unit" +version = "0.1.0" +dependencies = [ + "infra", + "proptest", + "tract-core", + "tract-transformers", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tar" +version = "0.4.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f6221d9a6003c78398e3b239969f352578258df48c8eb051caadae0015bc840" +dependencies = [ + "filetime", + "libc", + "xattr", +] + +[[package]] +name = "temp-dir" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "016ef9739649996fcc983b9c588fe3d557cf216d4d98503ce1b057ab5a66d689" + +[[package]] +name = "tempfile" +version = "3.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +dependencies = [ + "fastrand", + "getrandom 0.4.2", + "once_cell", + "rustix 1.1.4", + "windows-sys 0.61.2", +] + +[[package]] +name = "terminal_size" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "230a1b821ccbd75b185820a1f1ff7b14d21da1e442e22c0863ea5f08771a8874" +dependencies = [ + "rustix 1.1.4", + "windows-sys 0.61.2", +] + +[[package]] +name = "test-cuda" +version = "0.1.0" +dependencies = [ + "home", + "infra", + "lazy_static", + "log", + "pastey 0.2.3", + "regex", + "suite-onnx", + "suite-unit", + "tract-core", + "tract-cuda", + "tract-gpu", + "tract-onnx-opl", +] + +[[package]] 
+name = "test-f16" +version = "0.1.0" +dependencies = [ + "home", + "infra", + "lazy_static", + "log", + "regex", + "suite-onnx", + "suite-unit", + "tract-core", + "tract-nnef", + "tract-onnx-opl", + "tract-transformers", +] + +[[package]] +name = "test-metal" +version = "0.1.0" +dependencies = [ + "home", + "infra", + "lazy_static", + "log", + "pastey 0.2.3", + "regex", + "suite-onnx", + "suite-unit", + "tract-core", + "tract-gpu", + "tract-metal", + "tract-onnx-opl", +] + +[[package]] +name = "test-nnef-cycle" +version = "0.1.0" +dependencies = [ + "infra", + "itertools 0.14.0", + "lazy_static", + "log", + "suite-onnx", + "suite-unit", + "tract-core", + "tract-nnef", + "tract-onnx-opl", + "tract-transformers", +] + +[[package]] +name = "test-onnx-core" +version = "0.20.7-pre" +dependencies = [ + "lazy_static", + "suite-onnx", + "tract-core", + "tract-nnef", + "tract-onnx", +] + +[[package]] +name = "test-tflite" +version = "0.1.0" +dependencies = [ + "home", + "infra", + "lazy_static", + "log", + "regex", + "suite-onnx", + "suite-unit", + "tflitec", + "tract-core", + "tract-onnx-opl", + "tract-tflite", +] + +[[package]] +name = "test-unit-core" +version = "0.1.0" +dependencies = [ + "suite-unit", + "tract-core", +] + +[[package]] +name = "tf-inceptionv3" +version = "0.20.7-pre" +dependencies = [ + "criterion", + "dinghy-test", + "env_logger", + "image", + "log", + "tract-tensorflow", +] + +[[package]] +name = "tf-mobilenet-v2" +version = "0.20.7-pre" +dependencies = [ + "dinghy-test", + "image", + "tract-tensorflow", +] + +[[package]] +name = "tfl-mobilenet-v2-q" +version = "0.20.7-pre" +dependencies = [ + "dinghy-test", + "image", + "tract-tflite", +] + +[[package]] +name = "tflitec" +version = "0.6.0" +source = "git+https://github.com/kali/tflitec-rs.git?rev=9ceb838#9ceb838839d0481030aa12e95d8cb28f659f7a48" +dependencies = [ + "bindgen 0.65.1", + "curl", + "fs_extra", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tiff" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b63feaf3343d35b6ca4d50483f94843803b0f51634937cc2ec519fc32232bc52" +dependencies = [ + "fax", + "flate2", + "half", + "quick-error 2.0.1", + "weezl", + "zune-jpeg", +] + +[[package]] +name = "tinystr" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "tokenizers" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44e5bea67576e04b6ff8564c5d9e09c2ef0cf476502245f2f120e497769d3112" +dependencies = [ + "ahash", + "compact_str", + 
"daachorse", + "dary_heap", + "derive_builder", + "esaxx-rs", + "fancy-regex", + "getrandom 0.3.4", + "indicatif", + "itertools 0.14.0", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.9.4", + "rayon", + "rayon-cond", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror 2.0.18", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + +[[package]] +name = "tokio" +version = "1.52.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc7f01b389ac15039e4dc9531aa973a135d7a4135281b12d7c1bc79fd57fffe" +dependencies = [ + "bytes", + "libc", + "mio", + "pin-project-lite", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tokio-scoped" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4beb8ba13bc53ac53ce1d52b42f02e5d8060f0f42138862869beb769722b256" +dependencies = [ + "tokio", + "tokio-stream", +] + +[[package]] +name = "tokio-stream" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "toml_datetime" +version = "1.1.1+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_edit" +version = "0.25.12+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2153edc6955a6c354fad8f5efd38b6a8769bdccf9fe50f8e1329f81b0baa5d7" +dependencies = [ + "indexmap", + "toml_datetime", + "toml_parser", + "winnow", +] + +[[package]] +name = "toml_parser" +version = "1.1.2+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" +dependencies = [ + "winnow", +] + +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cfcf7e2740e6fc6d4d688b4ef00650406bb94adf4731e43c096c3a19fe40840" +dependencies = [ + "bitflags 2.11.1", + "bytes", + "futures-util", + "http", + "http-body", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", + "url", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "log", 
+ "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "tract" +version = "0.23.1-pre" +dependencies = [ + "anyhow", + "boow", + "erased-serde", + "flate2", + "half", + "icu_normalizer", + "icu_properties", + "reqwest", + "rustls", + "serde_json", + "tempfile", + "tract-api", + "tract-cuda", + "tract-extra", + "tract-libcli", + "tract-metal", + "tract-nnef", + "tract-onnx", + "tract-onnx-opl", + "tract-pulse", + "tract-transformers", +] + +[[package]] +name = "tract-api" +version = "0.23.1-pre" +dependencies = [ + "anyhow", + "boow", + "flate2", + "half", + "lazy_static", + "reqwest", + "serde", + "serde_json", + "tempfile", +] + +[[package]] +name = "tract-cli" +version = "0.23.1-pre" +dependencies = [ + "box_drawing", + "clap", + "colorous", + "criterion", + "cudarc", + "env_logger", + "erased-serde", + "flate2", + "float-ord", + "fs-err", + "icu_normalizer", + "icu_normalizer_data", + "icu_properties", + "icu_properties_data", + "idna_adapter", + "inventory", + "lazy_static", + "litemap", + "log", + "ndarray-npy", + "nu-ansi-term", + "num_cpus", + "py_literal", + "readings-probe", + "regex", + "reqwest", + "ron", + "rustls", + "scan_fmt", + "serde", + "serde_json", + "tract-core", + "tract-cuda", + "tract-extra", + "tract-gpu", + "tract-hir", + "tract-libcli", + "tract-linalg", + "tract-metal", + "tract-nnef", + "tract-nnef-resources", + "tract-onnx", + "tract-pulse", + "tract-pulse-opl", + "tract-tensorflow", + "tract-tflite", + "tract-transformers", + "webpki-roots", + "zerofrom", +] + +[[package]] +name = "tract-core" +version = "0.23.1-pre" +dependencies = [ + "anyhow", + "anymap3", + "approx", + "bit-set 0.10.0", + "criterion", + "derive-new", + "downcast-rs", + "dyn-clone", + "dyn-eq", + "env_logger", + 
"erased-serde", + "inventory", + "lazy_static", + "log", + "maplit", + "ndarray", + "num-complex", + "num-integer", + "num-traits", + "pastey 0.2.3", + "proptest", + "rustfft", + "serde", + "smallvec", + "tract-data", + "tract-linalg", +] + +[[package]] +name = "tract-cuda" +version = "0.23.1-pre" +dependencies = [ + "anyhow", + "criterion", + "cudarc", + "derive-new", + "dirs", + "downcast-rs", + "dyn-eq", + "inventory", + "libloading 0.9.0", + "log", + "minijinja", + "num-traits", + "proptest", + "rand 0.10.1", + "tract-core", + "tract-gpu", + "tract-pulse-opl", + "tract-transformers", +] + +[[package]] +name = "tract-data" +version = "0.23.1-pre" +dependencies = [ + "anyhow", + "criterion", + "downcast-rs", + "dyn-clone", + "dyn-eq", + "dyn-hash", + "half", + "itertools 0.14.0", + "lazy_static", + "libm", + "maplit", + "ndarray", + "nom 8.0.0", + "nom-language", + "num-complex", + "num-integer", + "num-traits", + "parking_lot", + "proptest", + "scan_fmt", + "smallvec", + "string-interner", +] + +[[package]] +name = "tract-extra" +version = "0.23.1-pre" +dependencies = [ + "approx", + "criterion", + "env_logger", + "lazy_static", + "proptest", + "tract-nnef", + "tract-pulse", +] + +[[package]] +name = "tract-ffi" +version = "0.23.1-pre" +dependencies = [ + "anyhow", + "flate2", + "serde", + "serde_json", + "tract", + "tract-api", +] + +[[package]] +name = "tract-gpu" +version = "0.23.1-pre" +dependencies = [ + "anyhow", + "derive-new", + "downcast-rs", + "dyn-eq", + "dyn-hash", + "num-traits", + "tract-core", + "tract-pulse-opl", + "tract-transformers", +] + +[[package]] +name = "tract-hir" +version = "0.23.1-pre" +dependencies = [ + "derive-new", + "env_logger", + "log", + "tract-core", +] + +[[package]] +name = "tract-libcli" +version = "0.23.1-pre" +dependencies = [ + "box_drawing", + "clap", + "colorous", + "cudarc", + "lazy_static", + "log", + "ndarray-npy", + "nu-ansi-term", + "py_literal", + "rand 0.10.1", + "serde", + "serde_json", + "tract-core", + 
"tract-cuda", + "tract-gpu", + "tract-hir", + "tract-metal", + "tract-onnx", + "tract-tflite", + "tract-transformers", +] + +[[package]] +name = "tract-linalg" +version = "0.23.1-pre" +dependencies = [ + "byteorder", + "cc", + "core_affinity", + "criterion", + "derive-new", + "downcast-rs", + "dyn-clone", + "dyn-eq", + "dyn-hash", + "env_logger", + "half", + "lazy_static", + "libc", + "log", + "minijinja", + "nu-ansi-term", + "num-traits", + "pastey 0.2.3", + "proptest", + "rayon", + "scan_fmt", + "tract-data", + "walkdir", +] + +[[package]] +name = "tract-metal" +version = "0.23.1-pre" +dependencies = [ + "anyhow", + "criterion", + "derive-new", + "downcast-rs", + "ggml", + "inventory", + "log", + "metal", + "num-traits", + "objc", + "proptest", + "rand 0.10.1", + "tract-core", + "tract-gpu", + "tract-pulse-opl", + "tract-transformers", +] + +[[package]] +name = "tract-nnef" +version = "0.23.1-pre" +dependencies = [ + "byteorder", + "erased-serde", + "flate2", + "log", + "minijinja", + "nom 8.0.0", + "nom-language", + "safetensors", + "serde", + "serde_json", + "simd-adler32", + "tar", + "temp-dir", + "tract-core", + "walkdir", +] + +[[package]] +name = "tract-nnef-cli" +version = "0.21.8-pre" +dependencies = [ + "anyhow", + "clap", + "env_logger", + "log", + "tract-nnef", + "tract-onnx-opl", + "tract-pulse", +] + +[[package]] +name = "tract-nnef-resources" +version = "0.23.1-pre" +dependencies = [ + "anyhow", + "nom 8.0.0", + "nom-language", + "serde", + "serde_json", + "tract-nnef", +] + +[[package]] +name = "tract-onnx" +version = "0.23.1-pre" +dependencies = [ + "bytes", + "criterion", + "derive-new", + "dyn-eq", + "env_logger", + "log", + "memmap2 0.9.10", + "num-integer", + "prost", + "rand 0.10.1", + "rayon", + "smallvec", + "tract-extra", + "tract-hir", + "tract-nnef", + "tract-onnx-opl", + "tract-transformers", +] + +[[package]] +name = "tract-onnx-opl" +version = "0.23.1-pre" +dependencies = [ + "dyn-eq", + "env_logger", + "getrandom 0.4.2", + "log", + 
"rand 0.10.1", + "rand_distr 0.6.0", + "rustfft", + "tract-extra", + "tract-nnef", +] + +[[package]] +name = "tract-proxy" +version = "0.23.1-pre" +dependencies = [ + "anyhow", + "boow", + "home", + "ndarray", + "reqwest", + "rustls", + "serde_json", + "tempfile", + "tract-api", + "tract-proxy-sys", +] + +[[package]] +name = "tract-proxy-sys" +version = "0.23.1-pre" +dependencies = [ + "bindgen 0.72.1", +] + +[[package]] +name = "tract-pulse" +version = "0.23.1-pre" +dependencies = [ + "downcast-rs", + "dyn-eq", + "erased-serde", + "lazy_static", + "log", + "serde", + "tract-pulse-opl", + "tract-transformers", +] + +[[package]] +name = "tract-pulse-opl" +version = "0.23.1-pre" +dependencies = [ + "downcast-rs", + "dyn-eq", + "lazy_static", + "tract-nnef", +] + +[[package]] +name = "tract-tensorflow" +version = "0.23.1-pre" +dependencies = [ + "bytes", + "criterion", + "derive-new", + "env_logger", + "log", + "memmap2 0.9.10", + "proptest", + "prost", + "prost-types", + "rand 0.10.1", + "tract-hir", + "tract-pulse", +] + +[[package]] +name = "tract-tflite" +version = "0.23.1-pre" +dependencies = [ + "derive-new", + "flatbuffers", + "tract-core", +] + +[[package]] +name = "tract-transformers" +version = "0.23.1-pre" +dependencies = [ + "float-ord", + "rayon", + "tract-nnef", +] + +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "typeid" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc7d623258602320d5c55d1bc22793b57daff0ec7efc270ea7d55ce1d5f5471c" + +[[package]] +name = "typenum" +version = 
"1.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" + +[[package]] +name = "ucd-trie" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" + +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec", +] + +[[package]] +name = "unicode-segmentation" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" + +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + +[[package]] +name = "unit-prefix" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"81e544489bf3d8ef66c953931f56617f423cd4b5494be343d9b9d3dda037b9a3" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "url" +version = "2.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "v_frame" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "666b7727c8875d6ab5db9533418d7c764233ac9c0cff1d469aec8fa127597be2" +dependencies = [ + "aligned-vec", + "num-traits", + "wasm-bindgen", +] + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wait-timeout" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" +dependencies = [ + "libc", +] + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.3+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" +dependencies = [ + "wit-bindgen 0.57.1", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen 0.51.0", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.122" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ed04576f974d2b2fba0f38c51dbc5518011e38c36bf1143164be765528fd409" +dependencies = [ + "cfg-if 1.0.4", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9473dbd2991ae90b6291c3c32c30c6187ac49aa32f9905d1cce280ec1e110b0f" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.122" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "916151b09da36bd82f6615cbf3a419e2f0ba23a03c6160e8e92eb6bd4aa1dec6" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.122" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "299047362ccbfce148b67ab7e73349f77748e00c8296f9542adfad2ad82c5c5e" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.122" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a929b2c61f11ba3e9bc35b50c1f25cb38e0e892c0c231ae2b8cf78d5dad4437" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasm-model-bench" +version = "0.1.0" +dependencies = [ + "anyhow", + "tract-nnef", + "tract-onnx", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags 2.11.1", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "web-sys" +version = "0.3.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d621441cfc37b84979402712047321980c178f299193a3589d05b99e8763436" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = 
"web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-root-certs" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31141ce3fc3e300ae89b78c0dd67f9708061d1d2eda54b8209346fd6be9a92c" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "webpki-roots" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f5ee44c96cf55f1b349600768e3ece3a8f26010c05265ab73f945bb1a2eb9d" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "weezl" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" + +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix 0.38.44", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version 
= "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "527fadee13e0c05939a6a05d5bd6eec6cd2e3dbd648b9f8e447c6518133d8580" +dependencies = [ + "windows-collections", + "windows-core", + "windows-future", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b2d95af1a8a14a3c7367e1ed4fc9c20e0a26e79551b1454d72583c97cc6610" +dependencies = [ + "windows-core", +] + +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-future" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1d6f90251fe18a279739e78025bd6ddc52a7e22f921070ccdc67dde84c605cb" +dependencies = [ + "windows-core", + "windows-link", + "windows-threading", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-numerics" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e40844ac143cdb44aead537bbf727de9b044e107a0f1220392177d15b0f26" +dependencies = [ + "windows-core", + "windows-link", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-threading" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3949bd5b99cafdf1c7ca86b43ca564028dfe27d66958f2470940f73d86d75b37" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "winnow" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0592e1c9d151f854e6fd382574c3a0855250e1d9b2f99d9281c6e6391af352f1" +dependencies = [ + "memchr", +] + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen" +version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags 2.11.1", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" + +[[package]] +name = "xattr" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" +dependencies = [ + "libc", + "rustix 1.1.4", +] + +[[package]] +name = "y4m" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5a4b21e1a62b67a2970e6831bc091d7b87e119e7f9791aef9702e3bef04448" + +[[package]] +name = "yoke" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" 
+dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerocopy" +version = "0.8.49" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bce33a6288fa3f072a8c2c7d0f2fdbb90e28298f0135c1f99b96c3db2efcc60b" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.49" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd425244944f4ab65ccff928e7323354c5a018c75838362fdce749dfad2ee1e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zerofrom" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ec05a11813ea801ff6d75110ad09cd0824ddba17dfe17128ea0d5f68e6c5272" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f9152d31db0792fa83f70fb2f83148effb5c1f5b8c7686c3459e361d9bc20bf" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zip" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb2a05c7c36fde6c09b08576c9f7fb4cda705990f73b58fe011abf7dfb24168b" +dependencies = [ + "arbitrary", + "crc32fast", + "flate2", + "indexmap", + "memchr", + "zopfli", +] + +[[package]] +name = "zlib-rs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3be3d40e40a133f9c916ee3f9f4fa2d9d63435b5fbe1bfc6d9dae0aa0ada1513" + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" + +[[package]] +name = "zopfli" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f05cd8797d63865425ff89b5c4a48804f35ba0ce8d125800027ad6017d2b5249" +dependencies = [ + "bumpalo", + "crc32fast", + "log", + "simd-adler32", +] + +[[package]] +name = "zune-core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9" + +[[package]] +name = "zune-inflate" +version = "0.2.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73ab332fe2f6680068f3582b16a24f90ad7096d5d39b974d1c0aff0125116f02" +dependencies = [ + "simd-adler32", +] + +[[package]] +name = "zune-jpeg" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27bc9d5b815bc103f142aa054f561d9187d191692ec7c2d1e2b4737f8dbd7296" +dependencies = [ + "zune-core", +] diff --git a/Cargo.toml b/Cargo.toml index 81b9e48164..13662d8765 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -116,6 +116,8 @@ default-members = [ ] [workspace.package] +# MSRV source of truth: CI derives the tested toolchain from this. 
Keep the +# README rustc badge in sync when bumping. rust-version = "1.91" [workspace.dependencies] @@ -125,7 +127,6 @@ anstyle-query = "1.0.0" anyhow = "1.0.43" anymap3 = "1.0" approx = "0.5" -atty = "0.2.14" bit-set = "0.10.0" boow = "0.1.3" box_drawing = "0.1.2" @@ -206,37 +207,36 @@ smallvec = "1.6.1" string-interner = "0.20" tar = "0.4.37" tempfile = "3.8" -tensorflow = "0.21.0" tflitec = { git = "https://github.com/kali/tflitec-rs.git", rev="9ceb838" } time = "0.3.23" tokenizers = "0.23" unicode-normalization = "0.1.19" walkdir = "2.3.2" zerofrom = "0.1.5" -tract-api = { version = "=0.23.0-dev.4", path = 'api' } -tract-core = { version = "0.23.0-pre", path = 'core' } -tract-cuda = { version = "0.23.0-pre", path = 'cuda', default-features = false } -tract-data = { version = "0.23.0-pre", path = 'data' } -tract-extra = { version = "0.23.0-pre", path = 'extra' } -tract-gpu = { version = "0.23.0-pre", path = 'gpu' } -tract-hir = { version = "0.23.0-pre", path = 'hir' } -tract-libcli = { version = "0.23.0-pre", path = 'libcli' } -tract-linalg = { version = "0.23.0-pre", path = 'linalg' } -tract-metal = { version = "0.23.0-pre", path = 'metal' } -tract-nnef-resources = { version = "0.23.0-pre", path = 'nnef/nnef-resources' } -tract-nnef = { version = "0.23.0-pre", path = 'nnef' } -tract-onnx-opl = { version = "0.23.0-pre", path = 'onnx-opl' } -tract-onnx = { version = "0.23.0-pre", path = 'onnx' } -tract-pulse-opl = { version = "0.23.0-pre", path = 'pulse-opl' } -tract-pulse = { version = "0.23.0-pre", path = 'pulse' } -tract-tensorflow = { version = "0.23.0-pre", path = 'tensorflow' } -tract-tflite = { version = "0.23.0-pre", path = 'tflite' } -tract-transformers = { version = "0.23.0-pre", path = 'transformers' } -tract = { version = "0.23.0-pre", path = 'api/rs' } -tract-proxy-sys = { version = "0.23.0-pre", path = 'api/proxy/sys' } -tract-cli = { version = "0.23.0-pre", path = 'cli' } -tract-ffi = { version = "0.23.0-pre" } -tract-proxy = { version = 
"0.23.0-pre" } +tract-api = { version = "0.23.1-pre", path = 'api' } +tract-core = { version = "0.23.1-pre", path = 'core' } +tract-cuda = { version = "0.23.1-pre", path = 'cuda', default-features = false } +tract-data = { version = "0.23.1-pre", path = 'data' } +tract-extra = { version = "0.23.1-pre", path = 'extra' } +tract-gpu = { version = "0.23.1-pre", path = 'gpu' } +tract-hir = { version = "0.23.1-pre", path = 'hir' } +tract-libcli = { version = "0.23.1-pre", path = 'libcli' } +tract-linalg = { version = "0.23.1-pre", path = 'linalg' } +tract-metal = { version = "0.23.1-pre", path = 'metal' } +tract-nnef-resources = { version = "0.23.1-pre", path = 'nnef/nnef-resources' } +tract-nnef = { version = "0.23.1-pre", path = 'nnef' } +tract-onnx-opl = { version = "0.23.1-pre", path = 'onnx-opl' } +tract-onnx = { version = "0.23.1-pre", path = 'onnx' } +tract-pulse-opl = { version = "0.23.1-pre", path = 'pulse-opl' } +tract-pulse = { version = "0.23.1-pre", path = 'pulse' } +tract-tensorflow = { version = "0.23.1-pre", path = 'tensorflow' } +tract-tflite = { version = "0.23.1-pre", path = 'tflite' } +tract-transformers = { version = "0.23.1-pre", path = 'transformers' } +tract = { version = "0.23.1-pre", path = 'api/rs' } +tract-proxy-sys = { version = "0.23.1-pre", path = 'api/proxy/sys' } +tract-cli = { version = "0.23.1-pre", path = 'cli' } +tract-ffi = { version = "0.23.1-pre" } +tract-proxy = { version = "0.23.1-pre" } [profile.opt-no-lto] diff --git a/api/Cargo.toml b/api/Cargo.toml index d4a868f273..a7c0b1b77e 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-api" -version = "0.23.0-dev.4" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol "] description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/api/ffi/Cargo.toml b/api/ffi/Cargo.toml index f65b5b3698..6f7ce13bb6 100644 --- a/api/ffi/Cargo.toml +++ b/api/ffi/Cargo.toml @@ -1,6 +1,6 @@ [package] name 
= "tract-ffi" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol "] description = "Tiny, no-nonsense, self contained, neural network inference" diff --git a/api/proxy/Cargo.toml b/api/proxy/Cargo.toml index f7300b3415..d3f6cca395 100644 --- a/api/proxy/Cargo.toml +++ b/api/proxy/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-proxy" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol "] description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/api/proxy/sys/Cargo.toml b/api/proxy/sys/Cargo.toml index f82a5a9d75..4aea3fad78 100644 --- a/api/proxy/sys/Cargo.toml +++ b/api/proxy/sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-proxy-sys" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol "] description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/api/py/pyproject.toml b/api/py/pyproject.toml index 7043dd3362..53a36e0545 100644 --- a/api/py/pyproject.toml +++ b/api/py/pyproject.toml @@ -21,13 +21,27 @@ cargo --version || (curl https://sh.rustup.rs -sSf | sh -s -- -y --profile minim . $HOME/.cargo/env rustup toolchain add stable rustup default stable) -[ -e $HOME/.local/bin/sccache ] || ./.travis/setup-sccache.sh ] +[ -e $HOME/.local/bin/sccache ] || ./.travis/setup-sccache.sh ] +# cargo-auditable embeds the resolved Rust dep graph into the .so the +# wheel ships. setuptools_rust honours $CARGO (build.py:97); point it +# at a shim that prefixes `auditable` for build/rustc. +cargo install --locked cargo-auditable --version 0.7.0 +mkdir -p "$HOME/cargo-shim" +cat > "$HOME/cargo-shim/cargo" <<'EOSHIM' +#!/usr/bin/env bash +# `cargo auditable ` forwards transparently to cargo, so an +# unconditional prefix is safe for the version/metadata/build calls +# setuptools_rust makes. 
+exec cargo auditable "$@" +EOSHIM +chmod +x "$HOME/cargo-shim/cargo" """ environment = """ PATH=$HOME/.local/bin:$HOME/.cargo/bin:$PATH SCCACHE_DIR=$HOME/.cache/sccache SCCACHE_CACHE_SIZE=2G RUSTC_WRAPPER=sccache +CARGO=$HOME/cargo-shim/cargo """ [tool.cibuildwheel.macos] @@ -36,10 +50,27 @@ skip = "pp* cp???t-*" before-build = """ uv pip install --system "numpy>=2,<3" --config-settings=setup-args="-Dallow-noblas=true" rustup target add aarch64-apple-darwin -[ -e $HOME/.local/bin/sccache ] || ./.travis/setup-sccache.sh ] +[ -e $HOME/.local/bin/sccache ] || ./.travis/setup-sccache.sh ] +# Same cargo-auditable shim as Linux (see comment there). +cargo install --locked cargo-auditable --version 0.7.0 +mkdir -p "$HOME/cargo-shim" +cat > "$HOME/cargo-shim/cargo" <<'EOSHIM' +#!/usr/bin/env bash +# `cargo auditable ` forwards transparently to cargo, so an +# unconditional prefix is safe for the version/metadata/build calls +# setuptools_rust makes. +exec cargo auditable "$@" +EOSHIM +chmod +x "$HOME/cargo-shim/cargo" +""" +environment = """ +CARGO=$HOME/cargo-shim/cargo """ [tool.cibuildwheel.windows] +# TODO: add a .cmd shim equivalent of the Linux/macOS bash shim above to +# enable cargo-auditable here too. Windows wheels currently ship without +# the embedded Rust dep graph. 
before-build = """ choco install mingw --version=8.1.0 uv pip install --system "numpy==1.25.2" diff --git a/api/py/tests/mobilenet_onnx_test.py b/api/py/tests/mobilenet_onnx_test.py index 507c59af11..ba5451df25 100644 --- a/api/py/tests/mobilenet_onnx_test.py +++ b/api/py/tests/mobilenet_onnx_test.py @@ -99,7 +99,7 @@ def test_concretize(): typed = model.into_model() assert str(typed.input_fact(0)) == "B,3,224,224,f32" assert str(typed.output_fact(0)) == "B,1000,f32" - typed.transform(tract.ConcretizeSymbols({"B": 1})) + typed.transform(tract.SetSymbols({"B": 1})) assert str(typed.input_fact(0)) == "1,3,224,224,f32" assert str(typed.output_fact(0)) == "1,1000,f32" @@ -108,7 +108,7 @@ def test_concretize_builder(): model.set_input_fact(0, "B,3,224,224,f32") model.analyse() typed = model.into_model() - typed.transform(tract.ConcretizeSymbols().value("B", 1)) + typed.transform(tract.SetSymbols().value("B", 1)) assert str(typed.input_fact(0)) == "1,3,224,224,f32" assert str(typed.output_fact(0)) == "1,1000,f32" @@ -117,7 +117,7 @@ def test_concretize_raw_string(): model.set_input_fact(0, "B,3,224,224,f32") model.analyse() typed = model.into_model() - typed.transform('{"name":"concretize_symbols","values":{"B":1}}') + typed.transform('{"name":"set_symbols","values":{"B":1}}') assert str(typed.input_fact(0)) == "1,3,224,224,f32" assert str(typed.output_fact(0)) == "1,1000,f32" diff --git a/api/py/tract/__init__.py b/api/py/tract/__init__.py index 40c5166bf3..940bf4c639 100644 --- a/api/py/tract/__init__.py +++ b/api/py/tract/__init__.py @@ -49,7 +49,7 @@ from .inference_model import InferenceModel from .runnable import Runnable from .runtime import Runtime, runtime_for_name -from .transform import TransformSpec, ConcretizeSymbols, FloatPrecision, Pulse +from .transform import TransformSpec, SetSymbols, FloatPrecision, Pulse from .nnef import Nnef from .onnx import Onnx diff --git a/api/py/tract/model.py b/api/py/tract/model.py index ead15407d4..a143715018 100644 --- 
a/api/py/tract/model.py +++ b/api/py/tract/model.py @@ -109,7 +109,7 @@ def transform(self, transform: Union[str, "TransformSpec"]) -> None: - a plain string name (e.g. ``"f32_to_f16"``) - a JSON string with a ``"name"`` key and parameters - - a :class:`TransformSpec` subclass such as :class:`ConcretizeSymbols` + - a :class:`TransformSpec` subclass such as :class:`SetSymbols` or :class:`Pulse` """ self._valid() diff --git a/api/py/tract/transform.py b/api/py/tract/transform.py index 2926040bc4..35299d70b6 100644 --- a/api/py/tract/transform.py +++ b/api/py/tract/transform.py @@ -18,26 +18,28 @@ def to_json(self) -> str: ... -class ConcretizeSymbols(TransformSpec): - """Replace symbolic dimensions with concrete integer values. +class SetSymbols(TransformSpec): + """Bind symbolic dimensions to concrete integer values (or TDim expressions). Example:: - model.transform(ConcretizeSymbols({"B": 1})) + model.transform(SetSymbols({"B": 1})) # or with builder pattern: - model.transform(ConcretizeSymbols().value("B", 1)) + model.transform(SetSymbols().value("B", 1)) + # for symbolic TDim expressions, pass a string: + model.transform(SetSymbols().value("B", "2*S")) """ - def __init__(self, values: Optional[Dict[str, int]] = None): - self._values: Dict[str, int] = dict(values) if values else {} + def __init__(self, values: Optional[Dict[str, Union[int, str]]] = None): + self._values: Dict[str, Union[int, str]] = dict(values) if values else {} - def value(self, symbol: str, val: int) -> "ConcretizeSymbols": - """Set a symbol to a concrete value. Returns self for chaining.""" + def value(self, symbol: str, val: Union[int, str]) -> "SetSymbols": + """Bind a symbol to an int or TDim expression string. 
Returns self for chaining.""" self._values[symbol] = val return self def to_json(self) -> str: - return json.dumps({"name": "concretize_symbols", "values": self._values}) + return json.dumps({"name": "set_symbols", "values": self._values}) class Pulse(TransformSpec): diff --git a/api/rs/Cargo.toml b/api/rs/Cargo.toml index 457a943486..ca1014377f 100644 --- a/api/rs/Cargo.toml +++ b/api/rs/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol "] description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/deny.toml b/api/rs/deny.toml similarity index 97% rename from deny.toml rename to api/rs/deny.toml index cf3ebfbef2..66aa54784f 100644 --- a/deny.toml +++ b/api/rs/deny.toml @@ -34,7 +34,6 @@ deny = [ # Skip some multiple-versions checks, until they can be fixed. skip = [ { name = "hashbrown", version="<0.17" }, - { name = "foldhash", version="<0.2" }, { name = "cpufeatures", version="<0.3" }, { name = "cfg-if", version="<1" }, { name = "getrandom", version="<0.4" }, diff --git a/api/rs/src/lib.rs b/api/rs/src/lib.rs index 636829d586..bfafde8a3b 100644 --- a/api/rs/src/lib.rs +++ b/api/rs/src/lib.rs @@ -39,7 +39,7 @@ pub mod prelude { // User-facing API types pub use tract_api::{ - ConcretizeSymbols, Datum, DatumType, FloatPrecision, Pulse, TransformSpec, tensor, + Datum, DatumType, FloatPrecision, Pulse, SetSymbols, TransformSpec, tensor, }; // Traits needed for method resolution — hidden from namespace @@ -345,7 +345,10 @@ impl RunnableInterface for Runnable { type Fact = Fact; fn run(&self, inputs: impl IntoInputs) -> Result> { - StateInterface::run(&mut self.spawn_state()?, inputs.into_inputs()?) 
+ let inputs: TVec = + inputs.into_inputs()?.into_iter().map(|v| v.0.into_tvalue()).collect(); + let outputs = self.0.run(inputs)?; + Ok(outputs.into_iter().map(|t| Tensor(t.into_arc_tensor())).collect()) } fn spawn_state(&self) -> Result { diff --git a/api/src/lib.rs b/api/src/lib.rs index df3a35d6fc..4617529ee8 100644 --- a/api/src/lib.rs +++ b/api/src/lib.rs @@ -7,7 +7,7 @@ use std::path::Path; pub mod macros; pub mod transform; -pub use transform::{ConcretizeSymbols, FloatPrecision, Pulse, TransformConfig, TransformSpec}; +pub use transform::{FloatPrecision, Pulse, SetSymbols, TransformConfig, TransformSpec}; /// an implementation of tract's NNEF framework object /// diff --git a/api/src/transform.rs b/api/src/transform.rs index 91d4d09fd3..3f11359b0b 100644 --- a/api/src/transform.rs +++ b/api/src/transform.rs @@ -74,32 +74,60 @@ macro_rules! transform_config { }; } -/// Typed config for the `concretize_symbols` transform. +/// Typed config for the `set_symbols` transform. /// -/// Replaces symbolic dimensions with concrete integer values. +/// Binds symbolic dimensions to concrete integers (or `TDim` expressions +/// via [`Self::expr`]). 
/// /// # Example /// ```ignore -/// model.transform(ConcretizeSymbols::new().value("B", 1))?; +/// model.transform(SetSymbols::new().value("B", 1).value("T", 16))?; /// ``` #[derive(Debug, Clone, Default, serde::Serialize)] -pub struct ConcretizeSymbols { - values: HashMap, +pub struct SetSymbols { + #[serde(serialize_with = "serialize_values")] + values: HashMap, } -impl ConcretizeSymbols { +#[derive(Debug, Clone, serde::Serialize)] +#[serde(untagged)] +enum SetSymbolValue { + Int(i64), + Expr(String), +} + +fn serialize_values( + values: &HashMap, + s: S, +) -> Result { + use serde::ser::SerializeMap; + let mut map = s.serialize_map(Some(values.len()))?; + for (k, v) in values { + map.serialize_entry(k, v)?; + } + map.end() +} + +impl SetSymbols { pub fn new() -> Self { Self::default() } - /// Set a symbol to a concrete value. + /// Bind a symbol to a concrete integer value. pub fn value(mut self, symbol: impl Into, val: i64) -> Self { - self.values.insert(symbol.into(), val); + self.values.insert(symbol.into(), SetSymbolValue::Int(val)); + self + } + + /// Bind a symbol to a `TDim` expression (e.g. `"2*S"`) parsed against + /// the model's symbol scope at transform time. + pub fn expr(mut self, symbol: impl Into, expr: impl Into) -> Self { + self.values.insert(symbol.into(), SetSymbolValue::Expr(expr.into())); self } } -transform_config!(ConcretizeSymbols, "concretize_symbols"); +transform_config!(SetSymbols, "set_symbols"); /// Typed config for the `pulse` transform. 
/// diff --git a/api/tests/mobilenet/mod.rs b/api/tests/mobilenet/mod.rs index 2487a2c999..62e3ccc420 100644 --- a/api/tests/mobilenet/mod.rs +++ b/api/tests/mobilenet/mod.rs @@ -123,7 +123,7 @@ fn test_concretize() -> anyhow::Result<()> { let mut typed = model.into_model()?; assert_eq!(typed.input_fact(0)?.to_string(), "B,3,224,224,f32"); assert_eq!(typed.output_fact(0)?.to_string(), "B,1000,f32"); - typed.transform(ConcretizeSymbols::new().value("B", 1))?; + typed.transform(SetSymbols::new().value("B", 1))?; assert_eq!(typed.input_fact(0)?.to_string(), "1,3,224,224,f32"); assert_eq!(typed.output_fact(0)?.to_string(), "1,1000,f32"); Ok(()) @@ -136,7 +136,7 @@ fn test_concretize_raw_string() -> anyhow::Result<()> { model.set_input_fact(0, "B,3,224,224,f32")?; model.analyse()?; let mut typed = model.into_model()?; - typed.transform(r#"{"name":"concretize_symbols","values":{"B":1}}"#)?; + typed.transform(r#"{"name":"set_symbols","values":{"B":1}}"#)?; assert_eq!(typed.input_fact(0)?.to_string(), "1,3,224,224,f32"); assert_eq!(typed.output_fact(0)?.to_string(), "1,1000,f32"); Ok(()) diff --git a/cli/Cargo.toml b/cli/Cargo.toml index e17edb2504..0673673594 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-cli" -version = "0.23.0-pre" +version = "0.23.1-pre" authors = [ "Romain Liautaud ", "Mathieu Poumeyrol ", @@ -22,7 +22,6 @@ path = "src/main.rs" maintenance = { status = "actively-developed" } [dependencies] -atty.workspace = true box_drawing.workspace = true clap.workspace = true erased-serde.workspace = true @@ -74,8 +73,8 @@ float-ord.workspace = true tract-metal.workspace = true [target.'cfg(any(target_os = "linux", target_os = "windows"))'.dependencies] -cudarc.workspace = true -tract-cuda = { workspace = true, default-features = false } +cudarc = { workspace = true, optional = true } +tract-cuda = { workspace = true, optional = true } [features] default = [ @@ -96,23 +95,29 @@ pulse = ["tract-pulse", "tract-pulse-opl"] tf 
= ["tract-tensorflow", "tract-libcli/hir"] tflite = ["tract-tflite"] transformers = ["tract-transformers", "tract-libcli/transformers"] -conform = ["tract-tensorflow/conform"] multithread-mm = ["tract-linalg/multithread-mm"] +# Marker feature implicitly enabled by every cuda-XXXXX selector below. +# Use to gate cudarc-specific code paths in tract-cli. Not meant to be +# enabled directly — always pick a concrete cuda-XXXXX so cudarc has an +# API version to bind against. +cuda = [] + # CUDA driver-API selectors (linux/windows targets only). Picking one binds # cudarc against the matching CUDA enum/struct layout; the resulting binary # runs on any driver whose API is >= the chosen one. Exactly one must be # active. The default pulls cuda-13000 in; cross-compiles aimed at older -# drivers (e.g. aarch64 with CUDA 12) override with `--no-default-features`. -cuda-12000 = ["tract-cuda/cuda-12000"] -cuda-12010 = ["tract-cuda/cuda-12010"] -cuda-12020 = ["tract-cuda/cuda-12020"] -cuda-12030 = ["tract-cuda/cuda-12030"] -cuda-12040 = ["tract-cuda/cuda-12040"] -cuda-12050 = ["tract-cuda/cuda-12050"] -cuda-12060 = ["tract-cuda/cuda-12060"] -cuda-12080 = ["tract-cuda/cuda-12080"] -cuda-12090 = ["tract-cuda/cuda-12090"] -cuda-13000 = ["tract-cuda/cuda-13000"] -cuda-13010 = ["tract-cuda/cuda-13010"] -cuda-13020 = ["tract-cuda/cuda-13020"] +# drivers (e.g. aarch64 with CUDA 12) override with `--no-default-features` +# plus an explicit cuda-12XXX. 
+cuda-12000 = ["cuda", "dep:cudarc", "dep:tract-cuda", "tract-libcli/cuda", "tract-cuda/cuda-12000"] +cuda-12010 = ["cuda", "dep:cudarc", "dep:tract-cuda", "tract-libcli/cuda", "tract-cuda/cuda-12010"] +cuda-12020 = ["cuda", "dep:cudarc", "dep:tract-cuda", "tract-libcli/cuda", "tract-cuda/cuda-12020"] +cuda-12030 = ["cuda", "dep:cudarc", "dep:tract-cuda", "tract-libcli/cuda", "tract-cuda/cuda-12030"] +cuda-12040 = ["cuda", "dep:cudarc", "dep:tract-cuda", "tract-libcli/cuda", "tract-cuda/cuda-12040"] +cuda-12050 = ["cuda", "dep:cudarc", "dep:tract-cuda", "tract-libcli/cuda", "tract-cuda/cuda-12050"] +cuda-12060 = ["cuda", "dep:cudarc", "dep:tract-cuda", "tract-libcli/cuda", "tract-cuda/cuda-12060"] +cuda-12080 = ["cuda", "dep:cudarc", "dep:tract-cuda", "tract-libcli/cuda", "tract-cuda/cuda-12080"] +cuda-12090 = ["cuda", "dep:cudarc", "dep:tract-cuda", "tract-libcli/cuda", "tract-cuda/cuda-12090"] +cuda-13000 = ["cuda", "dep:cudarc", "dep:tract-cuda", "tract-libcli/cuda", "tract-cuda/cuda-13000"] +cuda-13010 = ["cuda", "dep:cudarc", "dep:tract-cuda", "tract-libcli/cuda", "tract-cuda/cuda-13010"] +cuda-13020 = ["cuda", "dep:cudarc", "dep:tract-cuda", "tract-libcli/cuda", "tract-cuda/cuda-13020"] diff --git a/cli/deny.toml b/cli/deny.toml new file mode 100644 index 0000000000..bee6142bb0 --- /dev/null +++ b/cli/deny.toml @@ -0,0 +1,53 @@ + +# add whatever else we support. 
+[graph] +targets = [ + { triple = "x86_64-unknown-linux-gnu" }, + { triple = "x86_64-unknown-linux-musl" }, + { triple = "x86_64-apple-darwin" }, + { triple = "x86_64-pc-windows-msvc" }, + { triple = "aarch64-linux-android" }, + { triple = "aarch64-unknown-linux-gnu" }, + { triple = "aarch64-unknown-linux-musl" }, + { triple = "aarch64-apple-ios" }, + { triple = "aarch64-apple-darwin" }, + { triple = "armv7-unknown-linux-gnueabihf" }, + { triple = "armv7-unknown-linux-musleabi" }, + { triple = "arm-unknown-linux-gnueabihf" }, + { triple = "wasm32-unknown-unknown" }, +] + +[advisories] +git-fetch-with-cli = true +yanked = "deny" +ignore = [ + "RUSTSEC-2024-0436", # paste unmaintained — transitive dep from metal, tokenizers, rav1e +] + +[bans] +multiple-versions = "allow" +wildcards = "allow" +deny = [ + # List crates we don't want in our dependency tree here. +] + +[sources] +# trusted git sources. +allow-git = [ + "https://github.com/rustformers/llm.git", +] + +[licenses] +allow = [ + "Apache-2.0", # https://tldrlegal.com/license/apache-license-2.0-(apache-2.0) + "MIT", # https://tldrlegal.com/license/mit-license + "Unicode-3.0", # https://spdx.org/licenses/Unicode-3.0.html + "Zlib", # https://tldrlegal.com/license/zlib-libpng-license + "ISC", # https://tldrlegal.com/license/isc-license + "MPL-2.0", # https://tldrlegal.com/license/mozilla-public-license-2.0-(mpl-2) + "BSD-3-Clause", # https://www.tldrlegal.com/license/bsd-3-clause-license-revised + "CDLA-Permissive-2.0", # https://cdla.dev/permissive-2-0/ +] + +clarify = [ +] diff --git a/cli/src/bench.rs b/cli/src/bench.rs index 445e2ab00f..630679d989 100644 --- a/cli/src/bench.rs +++ b/cli/src/bench.rs @@ -33,7 +33,7 @@ pub fn handle( limits.warmup(¶ms.req_runnable()?, &inputs)?; let (iters, dur) = { - #[cfg(any(target_os = "linux", target_os = "windows"))] + #[cfg(all(any(target_os = "linux", target_os = "windows"), feature = "cuda"))] let _profiler = 
sub_matches.get_flag("cuda-gpu-trace").then(cudarc::driver::safe::Profiler::new); limits.bench(¶ms.req_runnable()?, &inputs)? diff --git a/cli/src/compare.rs b/cli/src/compare.rs index b16650f24c..a2785cc4d9 100644 --- a/cli/src/compare.rs +++ b/cli/src/compare.rs @@ -26,7 +26,6 @@ pub fn handle( } let cumulative = sub_matches.get_flag("cumulative"); - let resilent = sub_matches.get_flag("resilient"); if sub_matches.get_one::("stage").is_some() { // --with is by pipeline and put in params return handle_reference_stage(cumulative, params, &output_params, &run_params); @@ -38,100 +37,9 @@ pub fn handle( if let Some(pbdir) = sub_matches.get_one::("pbdir") { return handle_pbdir(cumulative, pbdir, params, &output_params, &run_params); } - if sub_matches.get_flag("tf") { - return handle_tensorflow(cumulative, resilent, params, &output_params, &run_params); - } bail!("No comparison target found") } -#[cfg(not(feature = "conform"))] -pub fn handle_tensorflow( - _cumulative: bool, - _resilient: bool, - _params: &mut Parameters, - _output_params: &DisplayParams, - _run_params: &RunParams, -) -> TractResult<()> { - bail!("`tf` feature is required for this to work"); -} - -#[cfg(feature = "conform")] -pub fn handle_tensorflow( - cumulative: bool, - resilient: bool, - params: &mut Parameters, - output_params: &DisplayParams, - run_params: &RunParams, -) -> TractResult<()> { - let tract = ¶ms.tract_model; - let mut tf = params.tf_model.take().unwrap(); - // First generate random values for the inputs. - let input_facts = tract - .input_outlets() - .iter() - .map(|&i| tract.outlet_typedfact(i)) - .collect::>>()?; - let generated = crate::tensor::make_inputs(&*input_facts)?; - - // Execute the model on tensorflow first. 
- info!("Running the model on tensorflow."); - trace!("Inject inputs in tensorflow graph."); - let pairs: Vec<_> = tract - .input_outlets() - .iter() - .map(|s| &*tract.node_name(s.node)) - .zip(generated.iter().cloned()) - .collect(); - - trace!("Execute the model on tensorflow."); - let eval_order = tract.eval_order()?; - - let mut wanted_outputs: Vec<&str> = eval_order - .iter() - .filter(|&n| !tract.input_outlets().contains(&OutletId::new(*n, 0))) - .map(|&n| tract.node_name(n)) - .collect(); - - for o in tract.output_outlets() { - let name = &*tract.node_name(o.node); - if !wanted_outputs.contains(&name) { - wanted_outputs.push(name); - } - } - - let mut all_values: HashMap>> = HashMap::new(); - if resilient { - for name in wanted_outputs { - all_values.insert( - name.to_string(), - vec![ - tf.run(pairs.clone(), &name) - .map(|t| Arc::new(t[0].clone().into())) - .map_err(|e| e.into()), - ], - ); - } - } else { - tf.run_get_many(pairs, wanted_outputs)?.into_iter().for_each(|(k, v)| { - all_values.insert(k.to_string(), vec![Ok(v[0].clone().into())]); - }); - }; - - for (ix, input) in tract.input_outlets().iter().enumerate() { - let name = tract.node_name(input.node); - all_values.insert(name.to_string(), vec![Ok(generated[ix].clone().into_arc_tensor())]); - } - dispatch_model_no_pulse!(params.tract_model, |m| compare( - cumulative, - m, - &all_values, - ¶ms, - &output_params, - run_params, - ("tract", "tf"), - )) -} - pub fn handle_npz( cumulative: bool, npz: &str, @@ -516,8 +424,7 @@ pub fn handle_stream( } // Concretize the reference model and delegate to compare() - let concrete_ref = - Arc::new(reference.clone().substitute_symbols(&concrete_sym_values.to_dim_map())?); + let concrete_ref = Arc::new(reference.clone().set_symbols(&concrete_sym_values.to_dim_map())?); compare( false, &concrete_ref, diff --git a/cli/src/dump.rs b/cli/src/dump.rs index 03273f05fe..2fd3acc539 100644 --- a/cli/src/dump.rs +++ b/cli/src/dump.rs @@ -11,7 +11,7 @@ use 
tract_core::ops::matmul::optimized::{OptMatMul, ProtoFusedSpec}; use tract_core::ops::matmul::pack::DynPackedExoticFact; use tract_core::ops::scan::OptScan; #[allow(unused_imports)] -#[cfg(any(target_os = "linux", target_os = "windows"))] +#[cfg(all(any(target_os = "linux", target_os = "windows"), feature = "cuda"))] use tract_cuda::utils::ensure_cuda_runtime_dependencies; use tract_hir::internal::*; use tract_itertools::Itertools; @@ -146,17 +146,17 @@ pub fn handle( #[cfg(not(any(target_os = "macos", target_os = "ios")))] { - #[cfg(any(target_os = "linux", target_os = "windows"))] + #[cfg(all(any(target_os = "linux", target_os = "windows"), feature = "cuda"))] ensure_cuda_runtime_dependencies("GPU profiling called on non-GPU device")?; } - #[cfg(any(target_os = "linux", target_os = "windows"))] + #[cfg(all(any(target_os = "linux", target_os = "windows"), feature = "cuda"))] let is_cuda = matches.get_flag("cuda"); - #[cfg(not(any(target_os = "linux", target_os = "windows")))] + #[cfg(not(all(any(target_os = "linux", target_os = "windows"), feature = "cuda")))] let is_cuda = false; if is_cuda { - #[cfg(any(target_os = "linux", target_os = "windows"))] + #[cfg(all(any(target_os = "linux", target_os = "windows"), feature = "cuda"))] tract_cuda::with_cuda_stream(|s| { s.enable_profiling(); Ok(()) @@ -164,7 +164,7 @@ pub fn handle( } let before_node: Box = if is_cuda { - #[cfg(any(target_os = "linux", target_os = "windows"))] + #[cfg(all(any(target_os = "linux", target_os = "windows"), feature = "cuda"))] { Box::new(|node_id| { tract_cuda::with_cuda_stream(|s| { @@ -174,7 +174,10 @@ pub fn handle( .ok(); }) } - #[cfg(not(any(target_os = "linux", target_os = "windows")))] + #[cfg(not(all( + any(target_os = "linux", target_os = "windows"), + feature = "cuda" + )))] Box::new(|_| {}) } else { Box::new(|_| {}) @@ -186,7 +189,7 @@ pub fn handle( &[(usize, String)], ) -> TractResult<()>, > = if is_cuda { - #[cfg(any(target_os = "linux", target_os = "windows"))] + 
#[cfg(all(any(target_os = "linux", target_os = "windows"), feature = "cuda"))] { Box::new(|dg, prefix| { tract_cuda::with_cuda_stream(|s| { @@ -209,7 +212,10 @@ pub fn handle( }) }) } - #[cfg(not(any(target_os = "linux", target_os = "windows")))] + #[cfg(not(all( + any(target_os = "linux", target_os = "windows"), + feature = "cuda" + )))] Box::new(|_, _| Ok(())) } else { Box::new(|_, _| Ok(())) @@ -254,12 +260,20 @@ pub fn handle( } if sub_matches.contains_id("memory-arena") { - #[cfg(not(any(target_os = "macos", target_os = "ios")))] + #[cfg(all(any(target_os = "linux", target_os = "windows"), feature = "cuda"))] { ensure_cuda_runtime_dependencies( "Memory arena is only enabled for MacOS / iOS devices or CUDA devices", )?; } + #[cfg(not(any( + target_os = "macos", + target_os = "ios", + all(any(target_os = "linux", target_os = "windows"), feature = "cuda") + )))] + bail!( + "Memory arena requires CUDA (Linux/Windows with cuda-* feature) or Metal (macOS/iOS)" + ); crate::memory_arena::dump_metrics( ¶ms.req_typed_model(), &plan_options_from_subcommand(sub_matches)?, diff --git a/cli/src/hwbench.rs b/cli/src/hwbench.rs index 13fe3a32e4..9725a3dd77 100644 --- a/cli/src/hwbench.rs +++ b/cli/src/hwbench.rs @@ -1,3 +1,5 @@ +use std::io::IsTerminal; + use nu_ansi_term::Color::*; use tract_core::prelude::*; use tract_core::tract_data::itertools::Itertools; @@ -107,7 +109,7 @@ fn mmm(dt: DatumType, m: usize, k: usize, n: usize) -> TractResult<()> { pa.precursor().as_dt() == Some(dt) && pb.precursor().as_dt() == Some(dt) }) .map(|(mmm, pix, pa, pb)| { - if atty::is(atty::Stream::Stderr) { + if std::io::stderr().is_terminal() { eprint!("Benching {} ({pix})", mmm.name()); } let a = pa.prepare_one(&a, 1, 0).unwrap(); @@ -132,7 +134,7 @@ fn mmm(dt: DatumType, m: usize, k: usize, n: usize) -> TractResult<()> { .unwrap(); } }); - if atty::is(atty::Stream::Stderr) { + if std::io::stderr().is_terminal() { eprint!("\x1B[2K\r"); // clear current line + CR } let flops = (m * k * n) as 
f64 / time; diff --git a/cli/src/main.rs b/cli/src/main.rs index bde9564ce1..215e8f30b1 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -186,9 +186,6 @@ fn main() -> TractResult<()> { .value_parser(clap::builder::PossibleValuesParser::new(STAGES)) .help("Loading pipeline stage to compare with"), ) - .arg( - Arg::new("tf").long("tf").action(ArgAction::SetTrue).help("Compare against tensorflow"), - ) .arg( Arg::new("twice") .long("twice") @@ -210,7 +207,7 @@ fn main() -> TractResult<()> { ) .group( ArgGroup::new("reference") - .args(&["npz", "pbdir", "stage", "tf", "twice", "stream"]) + .args(&["npz", "pbdir", "stage", "twice", "stream"]) .required(true), ) .arg( @@ -559,7 +556,8 @@ fn run_options(command: clap::Command) -> clap::Command { .long("set") .action(clap::ArgAction::Append) .number_of_values(1) - .help("Set a symbol value before running the model (--set S=12)"), + .help("Bind a symbol before running the model. RHS is a TDim expression \ + reduced to i64 against symbols set so far (--set S=12, --set T=2*S)."), ) .arg( Arg::new("input-from-nnef").long("input-from-nnef").num_args(1).help( diff --git a/cli/src/params.rs b/cli/src/params.rs index 345287b33f..09a434d9ec 100644 --- a/cli/src/params.rs +++ b/cli/src/params.rs @@ -134,13 +134,6 @@ pub struct Parameters { pub tract_model: Arc, pub reference_model: Option>, - #[cfg(feature = "conform")] - pub tf_model: Option, - - #[cfg(not(feature = "conform"))] - #[allow(dead_code)] - pub tf_model: (), - pub tensors_values: TensorsValues, pub assertions: Assertions, @@ -460,7 +453,7 @@ impl Parameters { /// `value` may be a plain integer or any TDim expression parseable /// against the model's symbol scope (e.g. `2*S` to rebase the /// streaming symbol onto a finer-grained chunk symbol before - /// pulsification). Feeds straight into `model.substitute_symbols`. + /// pulsification). Feeds straight into `model.set_symbols`. 
pub fn parse_set_subs( typed_model: &TypedModel, set: impl Iterator>, @@ -778,15 +771,15 @@ impl Parameters { } if let Some(set) = matches.get_many::("set") { - // --set delegates to model.substitute_symbols with a - // Symbol → TDim map (same path the `concretize_symbols` + // --set delegates to model.set_symbols with a + // Symbol → TDim map (same path the `set_symbols` // model transform takes). Values may be plain integers or // TDim expressions (e.g. `--set T=2*S` to rebase the // streaming symbol). Const TDim tensors are rewritten - // through Const's own substitute_symbols hook. + // through Const's own set_symbols hook. let subs = Self::parse_set_subs(typed_model.as_ref().unwrap(), set)?; stage!("set", typed_model -> typed_model, move |m: TypedModel| { - m.substitute_symbols(&subs) + m.set_symbols(&subs) }); stage!("set-declutter", typed_model -> typed_model, |mut m| { let mut dec = tract_core::optim::Optimizer::declutter(); @@ -859,36 +852,17 @@ impl Parameters { info!("Model {filename:?} loaded"); info_usage("model loaded", probe); - let (need_tensorflow_model, need_reference_model) = match matches.subcommand() { + let need_reference_model = match matches.subcommand() { Some(("compare", sm)) => { if let Some(with) = sm.get_one::("stage").map(String::as_str) { - (false, Some(with)) + Some(with) } else if sm.get_flag("stream") { - (false, Some("declutter")) + Some("declutter") } else { - (true, None) - } - } - _ => (false, None), - }; - - #[cfg(not(feature = "conform"))] - let tf_model = (); - #[cfg(feature = "conform")] - let tf_model = if need_tensorflow_model { - info!("Tensorflow version: {}", tract_tensorflow::conform::tf::version()); - if matches.get_flag("determinize") { - if let SomeGraphDef::Tf(ref graph) = graph { - let graph = graph.write_to_bytes().unwrap(); - Some(tract_tensorflow::conform::tf::for_slice(&graph)?) - } else { - unreachable!() + None } - } else { - Some(tract_tensorflow::conform::tf::for_path(&filename)?) 
} - } else { - None + _ => None, }; let need_proto = matches.get_flag("proto") @@ -1062,7 +1036,6 @@ impl Parameters { runnable, tract_model, reference_model, - tf_model, tensors_values, assertions, machine_friendly: matches.get_flag("machine-friendly"), diff --git a/cli/src/tensor.rs b/cli/src/tensor.rs index f59ad6e745..f34826b3f6 100644 --- a/cli/src/tensor.rs +++ b/cli/src/tensor.rs @@ -72,13 +72,20 @@ pub fn run_params_from_subcommand( } if let Some(set) = sub_matches.get_many::("set") { + // Right-hand side is a TDim expression (e.g. `--set T=2*S`). Parse + // against the model's symbol scope, then reduce to `i64` with the + // symbols already set so far (CLI argument order matters when + // expressions reference other symbols). + let symbol_scope = params.tract_model.symbols(); for set in set { let set = set.as_str(); - let (sym, value) = set.split_once('=').context("--set expect S=12 form")?; + let (sym, value) = set.split_once('=').context("--set expects S=value form")?; + let dim = tract_core::internal::parse_tdim(symbol_scope, value) + .with_context(|| format!("--set: parsing TDim expression for {sym}={value}"))?; + let value: i64 = dim.eval_to_i64(&symbols).with_context(|| { + format!("--set {sym}={value}: resolving with current symbol values {symbols:?}") + })?; let sym = params.tract_model.get_or_intern_symbol(sym); - let value: i64 = value - .parse() - .with_context(|| format!("Can not parse symbol value in set {set}"))?; symbols.set(&sym, value); } } diff --git a/core/Cargo.toml b/core/Cargo.toml index f2a31c93b1..814043e490 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-core" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol "] description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/core/src/model/graph.rs b/core/src/model/graph.rs index 90d460cbb5..7d28c2bd1c 100644 --- a/core/src/model/graph.rs +++ 
b/core/src/model/graph.rs @@ -172,6 +172,24 @@ where Ok(self) } + /// Set model inputs by node name — mirror of [`Self::select_outputs_by_name`]. + /// Removed inputs become dangling Source nodes; declutter prunes them. + pub fn select_inputs_by_name( + &mut self, + inputs: impl IntoIterator>, + ) -> TractResult<()> { + self.set_input_names(inputs) + } + + /// Set model inputs by node name and return `self`. + pub fn with_inputs_by_name( + mut self, + inputs: impl IntoIterator>, + ) -> TractResult { + self.select_inputs_by_name(inputs)?; + Ok(self) + } + /// Get the `ix`-th input tensor type information. pub fn input_fact(&self, ix: usize) -> TractResult<&F> { let input = self.input_outlets()?[ix]; diff --git a/core/src/model/typed.rs b/core/src/model/typed.rs index ea6b2c0fe1..64e730ee66 100644 --- a/core/src/model/typed.rs +++ b/core/src/model/typed.rs @@ -256,7 +256,9 @@ impl TypedModel { Ok(()) } - pub fn substitute_symbols(&self, subs: &HashMap) -> TractResult { + /// Bind one or more symbols to concrete values or TDim expressions across + /// the whole graph. 
+ pub fn set_symbols(&self, subs: &HashMap) -> TractResult { crate::model::translator::Translate::translate_model(subs, self) } @@ -320,7 +322,7 @@ impl Translate, TypedFact, Box> for Has mapping: &HashMap, ) -> TractResult> { target.check_consistency()?; - let outlets = node.op.substitute_symbols(source, node, target, mapping, self)?; + let outlets = node.op.set_symbols(source, node, target, mapping, self)?; for &outlet in &outlets { let fact = &mut target.nodes[outlet.node].outputs[outlet.slot].fact; if fact.shape.volume().is_zero() diff --git a/core/src/ops/array/broadcast.rs b/core/src/ops/array/broadcast.rs index 1b86834413..d6fccab489 100644 --- a/core/src/ops/array/broadcast.rs +++ b/core/src/ops/array/broadcast.rs @@ -124,7 +124,7 @@ impl TypedOp for MultiBroadcastTo { crate::optim::propagate_roi::bubble_roi(model, node) } - fn substitute_symbols( + fn set_symbols( &self, _source: &TypedModel, node: &TypedNode, diff --git a/core/src/ops/array/dyn_slice.rs b/core/src/ops/array/dyn_slice.rs index 2b1de3c5c1..050a62d1f4 100644 --- a/core/src/ops/array/dyn_slice.rs +++ b/core/src/ops/array/dyn_slice.rs @@ -114,5 +114,18 @@ impl TypedOp for DynSlice { )?)) } + fn set_symbols( + &self, + _source: &TypedModel, + node: &TypedNode, + target: &mut TypedModel, + mapping: &HashMap, + subs: &HashMap, + ) -> TractResult> { + let op = DynSlice { axis: self.axis, len: self.len.substitute_all(subs)? 
}; + let inputs = node.inputs.iter().map(|i| mapping[i]).collect::>(); + target.wire_node(&node.name, op, &inputs) + } + as_op!(); } diff --git a/core/src/ops/array/slice.rs b/core/src/ops/array/slice.rs index 15a12fb328..e5ccf84a36 100644 --- a/core/src/ops/array/slice.rs +++ b/core/src/ops/array/slice.rs @@ -178,7 +178,7 @@ impl TypedOp for Slice { } } - fn substitute_symbols( + fn set_symbols( &self, _source: &TypedModel, node: &TypedNode, diff --git a/core/src/ops/array/tile.rs b/core/src/ops/array/tile.rs index 00e7e10a33..3988779dbf 100644 --- a/core/src/ops/array/tile.rs +++ b/core/src/ops/array/tile.rs @@ -45,7 +45,7 @@ impl EvalOp for Tile { impl TypedOp for Tile { as_op!(); - fn substitute_symbols( + fn set_symbols( &self, _source: &TypedModel, node: &TypedNode, diff --git a/core/src/ops/array/topk.rs b/core/src/ops/array/topk.rs index 024de8c8db..a98d36cf43 100644 --- a/core/src/ops/array/topk.rs +++ b/core/src/ops/array/topk.rs @@ -103,5 +103,22 @@ impl TypedOp for Topk { Ok(tvec!(fact_values, fact_indices)) } + fn set_symbols( + &self, + _source: &TypedModel, + node: &TypedNode, + target: &mut TypedModel, + mapping: &HashMap, + subs: &HashMap, + ) -> TractResult> { + let op = Topk { + axis: self.axis, + largest: self.largest, + fallback_k: self.fallback_k.substitute_all(subs)?, + }; + let inputs = node.inputs.iter().map(|i| mapping[i]).collect::>(); + target.wire_node(&node.name, op, &inputs) + } + as_op!(); } diff --git a/core/src/ops/change_axes.rs b/core/src/ops/change_axes.rs index 52fad5df56..5b79b4e187 100644 --- a/core/src/ops/change_axes.rs +++ b/core/src/ops/change_axes.rs @@ -906,7 +906,7 @@ impl TypedOp for AxisOp { Ok(Some(op)) } - fn substitute_symbols( + fn set_symbols( &self, _source: &TypedModel, node: &TypedNode, diff --git a/core/src/ops/cnn/deconv/deconv_sum.rs b/core/src/ops/cnn/deconv/deconv_sum.rs index 04479a9b2e..59b5fe5721 100644 --- a/core/src/ops/cnn/deconv/deconv_sum.rs +++ b/core/src/ops/cnn/deconv/deconv_sum.rs @@ -106,7 
+106,7 @@ impl TypedOp for DeconvSum { Ok(tvec!(inputs[0].datum_type.fact(shape))) } - fn substitute_symbols( + fn set_symbols( &self, _source: &TypedModel, node: &TypedNode, diff --git a/core/src/ops/fft.rs b/core/src/ops/fft.rs index 4e55aac9ad..3da8294c8c 100644 --- a/core/src/ops/fft.rs +++ b/core/src/ops/fft.rs @@ -3,7 +3,7 @@ use num_complex::Complex; use rustfft::num_traits::{Float, FromPrimitive}; use rustfft::{FftDirection, FftNum}; use tract_data::itertools::Itertools; -use tract_ndarray::Axis; +use tract_ndarray::Axis as NdAxis; #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Fft { @@ -93,6 +93,43 @@ impl TypedOp for Fft { Ok(tvec!(inputs[0].without_value())) } + fn axes_mapping( + &self, + inputs: &[&TypedFact], + _outputs: &[&TypedFact], + ) -> TractResult { + // Fft is rank-preserving but it is NOT axes-natural: two axes do + // not map 1-to-1 from input to output and must be declared as a + // separate input-only and output-only axis. + // + // - the FFT axis (`self.axis`): every output sample along it + // depends on every input sample, so the axis cannot be + // sliced or streamed. + // - the trailing complex axis (`rank - 1`): the FFT mixes the + // real and imaginary parts, so re/im do not map 1-to-1. + // + // Splitting them is exactly what makes the generic pulse fallback + // bail when asked to track a streaming axis through the FFT or + // complex axis, while every genuine batch axis stays 1-to-1 and + // is handled by the per-pulse `PulseWrappingOp`. No dedicated + // `Fft` pulsifier is needed. 
+ let rank = inputs[0].rank(); + let complex_axis = rank - 1; + let mut axes = tvec!(); + let mut alphabet = 'a'..; + for i in 0..rank { + if i == self.axis || i == complex_axis { + axes.push(crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1).input(0, i)); + axes.push(crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1).output(0, i)); + } else { + axes.push( + crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1).input(0, i).output(0, i), + ); + } + } + AxesMapping::new(1, 1, axes) + } + as_op!(); } @@ -168,7 +205,7 @@ impl Stft { } fft.process(&mut v); oslice - .index_axis_mut(Axis(self.axis), f) + .index_axis_mut(NdAxis(self.axis), f) .iter_mut() .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter())) .for_each(|(s, v)| *s = v); @@ -223,5 +260,42 @@ impl TypedOp for Stft { Ok(tvec!(inputs[0].datum_type.fact(shape))) } + fn axes_mapping( + &self, + inputs: &[&TypedFact], + _outputs: &[&TypedFact], + ) -> TractResult { + // Stft is NOT rank-preserving: it inserts a frame axis at + // `axis + 1`. The mapping is: + // - axes 0..self.axis (leading dims): 1-to-1 input <-> output. + // - input axis `self.axis` (the time axis) <-> output axis + // `self.axis` (now the n_frames axis -- same position, the + // dim shrinks from `T` to `(T - frame) / stride + 1`). + // - output axis `self.axis + 1` (the inserted frame axis): + // output-only, no input correspondence. + // - input axes `self.axis + 1..rank` (trailing dims incl. + // the complex pair) <-> output axes `self.axis + 2..rank+1` + // (shifted right by 1 to make room for the frame axis). + // + // Without this mapping the generic `PulseWrappingOp` fallback + // bails with "could not track pulsing axis" the moment a user + // streams a non-time axis through STFT (typical pattern: a + // batched STFT pipeline that streams the batch axis). 
+ let in_rank = inputs[0].rank(); + let mut axes = tvec!(); + let mut alphabet = 'a'..; + for i in 0..in_rank { + let out_axis = if i <= self.axis { i } else { i + 1 }; + axes.push( + crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1) + .input(0, i) + .output(0, out_axis), + ); + } + // Inserted frame axis (output-only). + axes.push(crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1).output(0, self.axis + 1)); + crate::axes::AxesMapping::new(1, 1, axes) + } + as_op!(); } diff --git a/core/src/ops/konst.rs b/core/src/ops/konst.rs index a2754d50d0..eb23b38b7f 100644 --- a/core/src/ops/konst.rs +++ b/core/src/ops/konst.rs @@ -76,7 +76,7 @@ impl TypedOp for Const { Ok(tvec!((Cost::Params(self.0.datum_type().unquantized()), self.0.len().into()))) } - fn substitute_symbols( + fn set_symbols( &self, _source: &TypedModel, node: &TypedNode, diff --git a/core/src/ops/mod.rs b/core/src/ops/mod.rs index b69eefc5c8..12db1cc7df 100644 --- a/core/src/ops/mod.rs +++ b/core/src/ops/mod.rs @@ -295,7 +295,7 @@ pub trait TypedOp: /// expressions (a concrete integer is `TDim::Val(v)`; an expression /// can be any other TDim, including symbolic ones). #[allow(unused_variables)] - fn substitute_symbols( + fn set_symbols( &self, source: &TypedModel, node: &TypedNode, diff --git a/core/src/ops/nn/rms_norm.rs b/core/src/ops/nn/rms_norm.rs index b224a894c3..f5da41e94b 100644 --- a/core/src/ops/nn/rms_norm.rs +++ b/core/src/ops/nn/rms_norm.rs @@ -28,7 +28,34 @@ impl EvalOp for RmsNorm { fn eval(&self, inputs: TVec) -> TractResult> { let input = args_1!(inputs); + let in_dt = input.datum_type(); + // Fast path: F32 or F16 input where the normalised axis is the last + // (contiguous) one. Use the fused tract_linalg::rms_norm_f32 kernel + // (AVX-512 when available; scalar fallback otherwise) instead of the + // 4-call MeanOfSquares + Add + Rsqrt + Mul composition below. 
~16-18x + // faster on Cascade Lake AVX-512, ~equivalent on the scalar fallback + // since the composition is also memory-bandwidth bound. + if matches!(in_dt, DatumType::F32 | DatumType::F16) + && input.rank() > 0 + && self.axis == input.rank() - 1 + { + let eps_f32: f32 = self.eps.cast_to_scalar::()?; + let mut buf = input.cast_to::()?.into_owned(); + let row_len = buf.shape()[self.axis]; + if row_len > 0 { + let n_rows: usize = buf.shape().iter().take(self.axis).product(); + let data = unsafe { buf.as_slice_mut_unchecked::() }; + let rms_norm = &tract_linalg::ops().rms_norm_f32; + for r in 0..n_rows { + let start = r * row_len; + rms_norm(&mut data[start..start + row_len], eps_f32); + } + } + return Ok(tvec![buf.cast_to_dt(in_dt)?.into_owned().into()]); + } + + // Slow path: original 4-call composition (kept for non-contiguous axes). let input_f32 = input.cast_to::()?.into_owned(); // eps inherits the input dtype from the declutter pattern (F16 when the // surrounding LayerNorm chain is F16). The MeanOfSquares + Add + Rsqrt @@ -41,7 +68,7 @@ impl EvalOp for RmsNorm { let mut a2 = Add.eval(a1.into_tvalue(), eps.into_tvalue(), DatumType::F32)?; Rsqrt {}.eval_in_place(&mut a2, None)?; let a3 = Mul.eval(a2.into_tvalue(), input_f32.into_tvalue(), DatumType::F32)?; - Ok(tvec![a3.cast_to_dt(input.datum_type())?.into_owned().into()]) + Ok(tvec![a3.cast_to_dt(in_dt)?.into_owned().into()]) } } @@ -205,4 +232,40 @@ mod tests { assert!(diff < 0.01, "lane {i}: got {} expected {}", g.to_f32(), e); } } + + /// Slow path: when the normalised axis is NOT the trailing one, the fast + /// path in `eval` (which dispatches to `tract_linalg::ops().rms_norm_f32`) + /// is skipped and the original 4-call `MeanOfSquares` + `Add` + `Rsqrt` + + /// `Mul` composition runs. Asserts the result is identical to a hand- + /// computed reference, so the slow path stays correct after the fast-path + /// addition. 
+ #[test] + fn eval_with_non_trailing_axis_f32() { + // 2x3 input, axis=0 means we normalise across the 2 rows for each + // column independently: + // col 0: [1, 4] → mean_sq = (1 + 16) / 2 = 8.5 → 1/√8.5 + // col 1: [2, 5] → mean_sq = (4 + 25) / 2 = 14.5 → 1/√14.5 + // col 2: [3, 6] → mean_sq = (9 + 36) / 2 = 22.5 → 1/√22.5 + let input = tensor2(&[[1.0_f32, 2.0, 3.0], [4.0, 5.0, 6.0]]); + let eps = tensor0(0.0_f32).into_arc_tensor(); + let op = RmsNorm { axis: 0, eps }; + let out = op.eval(tvec!(input.into())).expect("eval should not panic"); + let out = out.into_iter().next().unwrap().into_tensor(); + assert_eq!(out.datum_type(), DatumType::F32); + assert_eq!(out.shape(), &[2, 3]); + let got = unsafe { out.as_slice_unchecked::() }; + let inv = |ms: f32| ms.sqrt().recip(); + let expected: [f32; 6] = [ + 1.0 * inv(8.5), + 2.0 * inv(14.5), + 3.0 * inv(22.5), + 4.0 * inv(8.5), + 5.0 * inv(14.5), + 6.0 * inv(22.5), + ]; + for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() { + let diff = (g - e).abs(); + assert!(diff < 1e-5, "lane {i}: got {g}, want {e}, diff {diff}"); + } + } } diff --git a/core/src/ops/scan/decluttered.rs b/core/src/ops/scan/decluttered.rs index 9d638b4cd3..53e3f148f6 100644 --- a/core/src/ops/scan/decluttered.rs +++ b/core/src/ops/scan/decluttered.rs @@ -110,18 +110,27 @@ impl Scan { // on the first call or when reset_every_turn is set; otherwise the // body input is fed from the carried hidden_state. // - // Inlining is therefore safe iff the caller does not rely on - // tract's across-call carry, which we encode as one of: + // Inlining is safe iff the caller does not rely on tract's across-call + // carry. We accept it when: // - reset_every_turn (carry is cleared every turn anyway), or - // - external_state (caller plumbs the State input every call, - // e.g. parakeet decoder). - // Otherwise (internal-managed state, e.g. DFN3 GRU under pulse), - // bail. See issue #2157. 
- rule_if!( - self.reset_every_turn - || self.external_state - || !self.input_mapping.iter().any(InputMapping::is_state) - ); + // - external_state (explicitly asserted, e.g. force_scan_external_state), or + // - there is no State at all, or + // - the caller manages the state: every recurrent state has a + // last-value output that reaches a model output, so the caller can + // read the updated state and feed it back across calls (e.g. DTLN / + // parakeet decoders). A pulse model with tract-managed state (e.g. + // DFN3 GRU) does NOT export its state, so it is not inlined. + // See issue #2157. + let has_state = self.input_mapping.iter().any(InputMapping::is_state); + let state_outputs: Vec<_> = self.output_mapping.iter().filter(|m| m.state).collect(); + let state_exported = has_state + && !state_outputs.is_empty() + && state_outputs.iter().all(|m| { + m.last_value_slot.is_some_and(|slot| { + Self::outlet_reaches_model_output(model, OutletId::new(node.id, slot)) + }) + }); + rule_if!(self.reset_every_turn || self.external_state || !has_state || state_exported); let mut patch = TypedModelPatch::new("Inline single loop scan"); patch.model = self.body.clone(); for (outer_wire, inner_wire) in izip!(&node.inputs, &self.body.inputs) { @@ -138,6 +147,31 @@ impl Scan { Ok(Some(patch)) } + /// True if `start` (an outlet of this Scan, e.g. a state's last-value + /// output) is, or transitively feeds, a model output — i.e. the caller can + /// observe the updated state and thread it back across calls. Used to tell a + /// caller-managed-state model (safe to inline a single-iteration Scan) from + /// a pulse model whose state tract carries internally (must not inline). 
+ fn outlet_reaches_model_output(model: &TypedModel, start: OutletId) -> bool { + let outputs: std::collections::HashSet = model.outputs.iter().copied().collect(); + let mut seen: std::collections::HashSet = Default::default(); + let mut stack = vec![start]; + while let Some(o) = stack.pop() { + if outputs.contains(&o) { + return true; + } + if !seen.insert(o) { + continue; + } + for succ in &model.node(o.node).outputs[o.slot].successors { + for slot in 0..model.node(succ.node).outputs.len() { + stack.push(OutletId::new(succ.node, slot)); + } + } + } + false + } + fn declutter_body_axes( &self, _session: &mut OptimizerSession, @@ -894,7 +928,7 @@ impl TypedOp for Scan { Ok(None) } - fn substitute_symbols( + fn set_symbols( &self, _source: &TypedModel, node: &TypedNode, @@ -907,9 +941,9 @@ impl TypedOp for Scan { output_mapping: self .output_mapping .iter() - .map(|om| om.substitute_symbols(subs)) + .map(|om| om.set_symbols(subs)) .collect::>>()?, - body: self.body.substitute_symbols(subs)?, + body: self.body.set_symbols(subs)?, ..self.clone() }; target.wire_node(&node.name, op, &inputs) diff --git a/core/src/ops/scan/mod.rs b/core/src/ops/scan/mod.rs index 74530f8ec8..f2e8f486a0 100644 --- a/core/src/ops/scan/mod.rs +++ b/core/src/ops/scan/mod.rs @@ -52,7 +52,7 @@ impl OutputMapping { } impl OutputMapping { - pub fn substitute_symbols( + pub fn set_symbols( &self, subs: &std::collections::HashMap, ) -> TractResult> { diff --git a/core/src/ops/scan/optimized.rs b/core/src/ops/scan/optimized.rs index 05245d7a43..db1604f1ab 100644 --- a/core/src/ops/scan/optimized.rs +++ b/core/src/ops/scan/optimized.rs @@ -233,6 +233,15 @@ impl OpState for State { outputs.sort_by_key(|a| a.0); let mut outputs: TVec = outputs.into_iter().map(|(_slot, v)| v).collect(); + // The body runs the SAME plan with the SAME shapes every iteration, so + // resolve its symbols once and keep them across iters (light per-iter + // reset), and reuse one drained input buffer. 
This avoids the per-timestep + // symbol re-resolution + reallocation a plain `model_state.run()` would do. + // Cleared up front since the body state persists across outer Scan calls. + model_state.clear_resolved_symbols(); + let mut iter_inputs: TVec = tvec!(); + let mut symbols_resolved = false; + for i in 0..iters { *position += 1; if *position <= op.skip { @@ -240,28 +249,29 @@ impl OpState for State { } hidden_state.reverse(); - let iter_inputs: TVec = op - .input_mapping - .iter() - .enumerate() - .map(|(slot, m)| { - Ok(match m { - InputMapping::State => Some(hidden_state.pop().unwrap()), - InputMapping::Scan(info) => Some( - Self::slice_input(&inputs[slot], info.axis, i, info.chunk)? - .into_tvalue(), - ), - InputMapping::Full => Some(inputs[slot].clone()), - }) - }) - .collect::>>()? - .into_iter() - .flatten() - .collect(); - + iter_inputs.clear(); + for (slot, m) in op.input_mapping.iter().enumerate() { + iter_inputs.push(match m { + InputMapping::State => hidden_state.pop().unwrap(), + InputMapping::Scan(info) => { + Self::slice_input(&inputs[slot], info.axis, i, info.chunk)?.into_tvalue() + } + InputMapping::Full => inputs[slot].clone(), + }); + } trace!("iter_inputs #{i}: {iter_inputs:?}"); - let iter_outputs = - model_state.run(iter_inputs).with_context(|| "Evaluating inner body")?; + + // Lighter equivalent of `model_state.run(iter_inputs)`: resolve body + // symbols only on the first iteration, and reset between iters without + // discarding them. 
+ model_state.set_inputs_drain(&mut iter_inputs).context("Setting body inputs")?; + if !symbols_resolved { + model_state.resolve_symbols_with_states()?; + symbols_resolved = true; + } + model_state.exec().with_context(|| "Evaluating inner body")?; + let iter_outputs = model_state.outputs()?; + model_state.reset_turn_keep_symbols(); trace!("iter_outputs #{i}: {iter_outputs:?}"); for (v, mapping) in iter_outputs.into_iter().zip(&op.output_mapping) { diff --git a/core/src/ops/source.rs b/core/src/ops/source.rs index 4b4c32244d..20de5a1f6e 100644 --- a/core/src/ops/source.rs +++ b/core/src/ops/source.rs @@ -66,7 +66,7 @@ impl TypedOp for TypedSource { ))) } - fn substitute_symbols( + fn set_symbols( &self, _source: &TypedModel, node: &TypedNode, diff --git a/core/src/plan.rs b/core/src/plan.rs index 4998dc9e46..216b55d4da 100644 --- a/core/src/plan.rs +++ b/core/src/plan.rs @@ -247,11 +247,27 @@ where } /// Reset wires state. pub fn reset_turn(&mut self) -> TractResult<()> { + self.reset_turn_keep_symbols(); + self.turn_state.resolved_symbols = SymbolValues::default(); + Ok(()) + } + + /// Like [`reset_turn`] but keeps the resolved symbols (and scenario). Used by + /// `Scan`/`Loop` bodies, whose shapes are constant across iterations: it lets + /// the body resolve its symbols once and skip the per-iteration re-resolution + /// the full `reset_turn` + `run` cycle would otherwise force. + pub(crate) fn reset_turn_keep_symbols(&mut self) { for node in &self.plan.order { self.turn_state.values[*node] = None; } + } + + /// Clear resolved symbols (and scenario) without touching node values. Used at + /// the start of a fresh `Scan` evaluation, since the body state persists across + /// outer calls and a previous call may have left stale symbol resolutions. + pub(crate) fn clear_resolved_symbols(&mut self) { self.turn_state.resolved_symbols = SymbolValues::default(); - Ok(()) + self.turn_state.scenario = None; } /// Reset op inner state. 
@@ -263,7 +279,7 @@ where Ok(()) } - fn resolve_symbols_with_states(&mut self) -> TractResult<()> { + pub(crate) fn resolve_symbols_with_states(&mut self) -> TractResult<()> { for state in self .op_states .iter_mut() @@ -470,6 +486,22 @@ where Ok(()) } + /// Like [`set_inputs`] but drains the caller's buffer (leaving it empty with + /// its capacity intact) instead of consuming it, so a repeated caller (a + /// `Scan` body loop) can reuse one allocation across iterations. + pub(crate) fn set_inputs_drain(&mut self, inputs: &mut TVec) -> TractResult<()> { + ensure!( + inputs.len() == self.model().inputs.len(), + "Wrong number of inputs for model. Expected {} got {}", + self.model().inputs.len(), + inputs.len() + ); + for (ix, t) in inputs.drain(..).enumerate() { + self.set_input(ix, t)? + } + Ok(()) + } + fn resolve(state: &mut TurnState, expression: &TDim, provided: i64) -> TractResult<()> { if let TDim::Sym(sym) = expression && state.resolved_symbols.get(sym).is_none() diff --git a/core/src/runtime.rs b/core/src/runtime.rs index 06ecbca35a..a63616a72e 100644 --- a/core/src/runtime.rs +++ b/core/src/runtime.rs @@ -105,7 +105,7 @@ pub struct DefaultRuntime; impl Runtime for DefaultRuntime { fn name(&self) -> StaticName { - Cow::Borrowed("default") + Cow::Borrowed("cpu") } fn prepare_with_options( @@ -215,7 +215,34 @@ pub fn runtimes() -> impl Iterator { inventory::iter::().filter(|rt| rt.check().is_ok()).map(|ir| ir.0) } +/// Known GPU backends, tried in order when resolving the virtual `gpu` +/// (strict) / `gpu-or-cpu` (best-effort) names. +const GPU_RUNTIME_NAMES: &[&str] = &["metal", "cuda"]; + pub fn runtime_for_name(s: &str) -> TractResult> { + // Back-compat: `default` was the original name for the CPU runtime + // before it was renamed. Keep it working as a plain alias. 
+ let s = if s == "default" { "cpu" } else { s }; + if s == "gpu" || s == "gpu-or-cpu" { + let mut last_check_err: Option = None; + for name in GPU_RUNTIME_NAMES { + let Some(rt) = inventory::iter::().find(|rt| rt.name() == *name) + else { + continue; + }; + match rt.check() { + Ok(()) => return Ok(Some(rt.0)), + Err(e) => last_check_err = Some(e), + } + } + if s == "gpu" { + let detail = + last_check_err.map(|e| format!(" (last backend error: {e:#})")).unwrap_or_default(); + bail!("Runtime `gpu` requested but no GPU backend is available{detail}"); + } + // gpu-or-cpu: fall through to the cpu runtime. + return runtime_for_name("cpu"); + } rule_if_some!(rt = inventory::iter::().find(|rt| rt.name() == s)); rt.check()?; Ok(Some(rt.0)) diff --git a/core/src/transform.rs b/core/src/transform.rs index b83a592f81..40bf9d3be8 100644 --- a/core/src/transform.rs +++ b/core/src/transform.rs @@ -256,16 +256,16 @@ pub enum SymbolValueSpec { } #[derive(Debug, Default, serde::Deserialize)] -pub struct ConcretizeSymbolsConfig { +pub struct SetSymbolsConfig { pub values: std::collections::HashMap, } #[derive(Debug)] -struct ConcretizeSymbolsTransform(ConcretizeSymbolsConfig); +struct SetSymbolsTransform(SetSymbolsConfig); -impl ModelTransform for ConcretizeSymbolsTransform { +impl ModelTransform for SetSymbolsTransform { fn name(&self) -> StaticName { - "concretize_symbols".into() + "set_symbols".into() } fn transform(&self, model: &mut TypedModel) -> TractResult<()> { @@ -281,26 +281,26 @@ impl ModelTransform for ConcretizeSymbolsTransform { }; subs.insert(sym, dim); } - *model = model.substitute_symbols(&subs)?; + *model = model.set_symbols(&subs)?; Ok(()) } } -register_model_transform!("concretize_symbols", ConcretizeSymbolsConfig, |config| Ok(Box::new( - ConcretizeSymbolsTransform(config) +register_model_transform!("set_symbols", SetSymbolsConfig, |config| Ok(Box::new( + SetSymbolsTransform(config) ))); /// Ad-hoc fix-up for NNEF artifacts exported before Scan grew the -/// 
`external_state` flag (issue #2157). For every Scan in the model: -/// 1. Substitute the scan-axis symbol on the Scan input with 1 across the -/// whole model (caller is bound by the per-call seq=1 contract that -/// external state management implies). -/// 2. Set `external_state = true`. +/// `external_state` flag (issue #2157). Sets `external_state = true` on every +/// Scan, asserting that the caller plumbs initial state in and reads final +/// state out each call. Apply only when the loaded model is known to use +/// external state management, e.g. the parakeet decoder. Cheaper than +/// re-exporting cached NNEF. /// -/// After this transform, the standard declutter pipeline sees `iters == 1` -/// on each Scan and `declutter_single_loop` inlines the body. Apply only -/// when the loaded model is known to use external state management, e.g. -/// the parakeet decoder. Cheaper than re-exporting cached NNEF. +/// This does *not* touch the sequence dimension. Inlining the Scan body via +/// `declutter_single_loop` additionally requires `iters == 1`, which is the +/// caller's per-call contract — concretize it explicitly (e.g. `--set +/// TARGETS__TIME=1`), separately from this flag. 
#[derive(Debug)] struct ForceScanExternalState; @@ -310,22 +310,7 @@ impl ModelTransform for ForceScanExternalState { } fn transform(&self, model: &mut TypedModel) -> TractResult<()> { - use crate::ops::scan::{InputMapping, Scan}; - let mut subs: HashMap = HashMap::new(); - for node in &model.nodes { - let Some(scan) = node.op_as::() else { continue }; - for (slot, mapping) in scan.input_mapping.iter().enumerate() { - let InputMapping::Scan(info) = mapping else { continue }; - let outer = node.inputs[slot]; - let dim = &model.outlet_fact(outer)?.shape[info.axis]; - if let TDim::Sym(s) = dim { - subs.insert(s.clone(), TDim::Val(1)); - } - } - } - if !subs.is_empty() { - *model = model.substitute_symbols(&subs)?; - } + use crate::ops::scan::Scan; for node in &mut model.nodes { if let Some(scan) = node.op_as_mut::() { scan.external_state = true; @@ -362,6 +347,28 @@ register_model_transform!("select_outputs", SelectOutputsConfig, |config| Ok(Box SelectOutputsTransform(config) ))); +#[derive(Debug, serde::Deserialize, Default)] +pub struct SelectInputsConfig { + pub inputs: Vec, +} + +#[derive(Debug)] +struct SelectInputsTransform(SelectInputsConfig); + +impl ModelTransform for SelectInputsTransform { + fn name(&self) -> StaticName { + "select_inputs".into() + } + + fn transform(&self, model: &mut TypedModel) -> TractResult<()> { + model.select_inputs_by_name(self.0.inputs.iter()) + } +} + +register_model_transform!("select_inputs", SelectInputsConfig, |config| Ok(Box::new( + SelectInputsTransform(config) +))); + inventory::submit! 
{ ModelTransformFactory { name: "f32_to_f16", diff --git a/cuda/Cargo.toml b/cuda/Cargo.toml index 95c4884a8a..0cf4234c78 100644 --- a/cuda/Cargo.toml +++ b/cuda/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-cuda" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = [ "Louis Chouraki ", diff --git a/cuda/src/ops/quant_q81.rs b/cuda/src/ops/quant_q81.rs index 10c763d2a6..badf21cf56 100644 --- a/cuda/src/ops/quant_q81.rs +++ b/cuda/src/ops/quant_q81.rs @@ -104,7 +104,7 @@ impl TypedOp for CudaGgmlQuantQ81 { .with_context(|| format!("Error while computing facts for {:?}", self.name())) } - fn substitute_symbols( + fn set_symbols( &self, _source: &TypedModel, node: &TypedNode, diff --git a/data/Cargo.toml b/data/Cargo.toml index b4db993ed3..76b62b58a1 100644 --- a/data/Cargo.toml +++ b/data/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-data" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol "] description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/doc/symbolic-shapes.md b/doc/symbolic-shapes.md index b933a80864..b83187fbdd 100644 --- a/doc/symbolic-shapes.md +++ b/doc/symbolic-shapes.md @@ -118,7 +118,7 @@ let b = model.symbols.sym("B"); let mut subs = HashMap::new(); subs.insert(s, 224.into()); // S = 224 subs.insert(b, 1.into()); // B = 1 -let model = model.substitute_symbols(&subs)?; +let model = model.set_symbols(&subs)?; ``` After this, the shapes are pure `Val(_)` and downstream passes treat @@ -175,11 +175,11 @@ The CLI does the same things via flags: - `-i N,3,224,224,f32` — set an input fact (per-input). - `--set S=224 --set B=1` — bind symbols to constants. Equivalent to - `substitute_symbols` on the library side. + `set_symbols` on the library side. 
- `--input-from-bundle io.npz` — derive concrete input shapes from the tensors actually present in the bundle; runs both "set input facts" and "concretise symbols" in one step. The library equivalent - is the `set_input_fact` + `substitute_symbols` pair above. + is the `set_input_fact` + `set_symbols` pair above. ## ONNX gotchas diff --git a/examples/nemo-nemotron-asr/src/main.rs b/examples/nemo-nemotron-asr/src/main.rs index 9de861874c..0585219318 100644 --- a/examples/nemo-nemotron-asr/src/main.rs +++ b/examples/nemo-nemotron-asr/src/main.rs @@ -13,13 +13,14 @@ fn argmax(slice: &[f32]) -> Option { } fn concretize_batch(mut model: Model) -> anyhow::Result { - model.transform(ConcretizeSymbols::new().value("BATCH", 1))?; + model.transform(SetSymbols::new().value("BATCH", 1))?; Ok(model) } fn remove_length_input(mut model: Model) -> anyhow::Result { model .transform(r#"{"name":"patch","body":"length = tract_core_shape_of(input_signal)[1];"}"#)?; + model.transform(r#"{"name":"select_inputs","inputs":["input_signal"]}"#)?; Ok(model) } diff --git a/examples/nemo-nemotron-streaming-asr/src/main.rs b/examples/nemo-nemotron-streaming-asr/src/main.rs index 8f07c22a5d..4ce7210796 100644 --- a/examples/nemo-nemotron-streaming-asr/src/main.rs +++ b/examples/nemo-nemotron-streaming-asr/src/main.rs @@ -104,10 +104,11 @@ impl NemotronModels { eprint!("Loading preprocessor to {}...", runtime.name()?); let mut pp = nnef.load(format!("{assets}/model/preprocessor.nnef.tgz"))?; - pp.transform(ConcretizeSymbols::new().value("BATCH", 1))?; + pp.transform(SetSymbols::new().value("BATCH", 1))?; pp.transform( r#"{"name":"patch","body":"length = tract_core_shape_of(input_signal)[1];"}"#, )?; + pp.transform(r#"{"name":"select_inputs","inputs":["input_signal"]}"#)?; pp.transform(r#"{"name":"select_outputs","outputs":["processed_signal"]}"#)?; pp.transform(Pulse::new(config.preproc_pulse.to_string()).symbol("INPUT_SIGNAL__TIME"))?; let pp_delay = 
pp.property("pulse.delay")?.as_slice::()?[0].to_owned() as usize; @@ -120,11 +121,12 @@ impl NemotronModels { eprint!("Loading encoder to {}...", runtime.name()?); let mut enc = nnef.load(format!("{assets}/model/encoder.p1.nnef.tgz"))?; - enc.transform(ConcretizeSymbols::new().value("BATCH", 1))?; + enc.transform(SetSymbols::new().value("BATCH", 1))?; enc.transform("transformers_detect_all")?; enc.transform( r#"{"name":"patch","body":"length = tract_core_shape_of(audio_signal)[2];"}"#, )?; + enc.transform(r#"{"name":"select_inputs","inputs":["audio_signal"]}"#)?; enc.transform(r#"{"name":"select_outputs","outputs":["outputs"]}"#)?; enc.transform(Pulse::new(config.encoder_pulse.to_string()).symbol("AUDIO_SIGNAL__TIME"))?; let enc_delay = enc.property("pulse.delay")?.as_slice::()?[0].to_owned() as usize; @@ -137,14 +139,14 @@ impl NemotronModels { eprint!("Loading decoder to {}...", runtime.name()?); let mut dec = nnef.load(format!("{assets}/model/decoder.nnef.tgz"))?; - dec.transform(ConcretizeSymbols::new().value("BATCH", 1).value("TARGETS__TIME", 1))?; + dec.transform(SetSymbols::new().value("BATCH", 1).value("TARGETS__TIME", 1))?; let decoder = runtime.prepare(dec)?; eprintln!(" done."); eprint!("Loading joint to {}...", runtime.name()?); let mut jnt = nnef.load(format!("{assets}/model/joint.nnef.tgz"))?; jnt.transform( - ConcretizeSymbols::new() + SetSymbols::new() .value("BATCH", 1) .value("ENCODER_OUTPUTS__TIME", 1) .value("DECODER_OUTPUTS__TIME", 1), diff --git a/examples/wasm-model-bench/src/bench_onnx.rs b/examples/wasm-model-bench/src/bench_onnx.rs index 906d35154b..98830d90c5 100644 --- a/examples/wasm-model-bench/src/bench_onnx.rs +++ b/examples/wasm-model-bench/src/bench_onnx.rs @@ -49,7 +49,7 @@ fn main() -> Result<()> { let symbols_env = std::env::var("TRACT_BENCH_SYMBOLS").ok().filter(|s| !s.is_empty()); // When symbols are provided, the model is symbolic-shaped and we go straight - // to TypedModel → substitute_symbols, ignoring shape_spec (input 
shapes will + // to TypedModel → set_symbols, ignoring shape_spec (input shapes will // be derived from the symbol substitution). When no symbols, we use the // shape_spec to pin input facts on the InferenceModel before into_typed. let typed = if let Some(symbols_str) = symbols_env { @@ -60,7 +60,7 @@ fn main() -> Result<()> { let sym = typed.symbols.sym(k); subs.insert(sym, TDim::Val(v.parse::()?)); } - typed.substitute_symbols(&subs)? + typed.set_symbols(&subs)? } else { if let Some(spec) = shape_spec.as_deref().filter(|s| *s != "-") { for (i, one) in spec.split(';').enumerate() { diff --git a/extra/Cargo.toml b/extra/Cargo.toml index 29a20712c7..d3c2a3649c 100644 --- a/extra/Cargo.toml +++ b/extra/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-extra" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol "] description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/gpu/Cargo.toml b/gpu/Cargo.toml index 4c5bd73aa6..2d53fc0768 100644 --- a/gpu/Cargo.toml +++ b/gpu/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-gpu" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = [ "Hubert de La Jonquiere ", diff --git a/gpu/src/ops/change_axes.rs b/gpu/src/ops/change_axes.rs index 847cbe7dfe..f456ecfff0 100644 --- a/gpu/src/ops/change_axes.rs +++ b/gpu/src/ops/change_axes.rs @@ -162,7 +162,7 @@ impl TypedOp for GpuAxisOp { self.inner.axes_mapping(&ref_inputs, &ref_outputs) } - fn substitute_symbols( + fn set_symbols( &self, _source: &TypedModel, node: &TypedNode, diff --git a/gpu/src/ops/copy_based.rs b/gpu/src/ops/copy_based.rs index 92170f38a0..7ad30fbdbe 100644 --- a/gpu/src/ops/copy_based.rs +++ b/gpu/src/ops/copy_based.rs @@ -3,7 +3,7 @@ //! any backend-specific arguments. 
use tract_core::internal::*; -use tract_core::ops::array::{MultiBroadcastTo, Slice, TypedConcat}; +use tract_core::ops::array::{MultiBroadcastTo, Pad, Slice, TypedConcat}; use tract_pulse_opl::ops::{Delay, PulsePad}; use tract_transformers::ops::dyn_kv_cache::DynKeyValueCache; @@ -38,5 +38,10 @@ pub fn try_make_copy_based_op( if let Some(op) = node.op_as::() { return Ok(Some(Box::new(super::pulse::GpuPulsePad::new(op)?))); } + if let Some(op) = node.op_as::() + && let Some(gpu) = super::pad::GpuPad::from_core(op) + { + return Ok(Some(Box::new(gpu))); + } Ok(None) } diff --git a/gpu/src/ops/mod.rs b/gpu/src/ops/mod.rs index b424f50ef9..219e33aeac 100644 --- a/gpu/src/ops/mod.rs +++ b/gpu/src/ops/mod.rs @@ -12,6 +12,7 @@ pub mod gather; pub mod gelu_approximate; pub mod iff; pub mod leaky_relu; +pub mod pad; pub mod pulse; pub mod reduce; pub mod rms_norm; diff --git a/gpu/src/ops/pad.rs b/gpu/src/ops/pad.rs new file mode 100644 index 0000000000..c223ab9511 --- /dev/null +++ b/gpu/src/ops/pad.rs @@ -0,0 +1,91 @@ +use crate::tensor::{DeviceTensorExt, IntoDevice}; +use tract_core::internal::*; +use tract_core::ops::array::{Pad, PadMode}; + +/// Constant padding via two `copy_nd`s: broadcast the pad value across the whole +/// output, then drop the input into the interior. No dedicated kernel. Reflect/ +/// Edge modes are left on the host (see [`GpuPad::from_core`]). +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct GpuPad { + pub pads: Vec<(usize, usize)>, + pub value: Arc, +} + +impl GpuPad { + /// Build from a core `Pad`, or `None` when the mode isn't `Constant`. 
+ pub fn from_core(op: &Pad) -> Option { + let PadMode::Constant(value) = &op.mode else { return None }; + Some(Self { pads: op.pads.clone(), value: value.clone() }) + } + + fn output_shape(&self, input: &[D]) -> TVec { + input.iter().zip(&self.pads).map(|(d, (a, b))| d.clone() + *a + *b).collect() + } +} + +impl Op for GpuPad { + fn name(&self) -> StaticName { + "GpuPad".into() + } + + op_as_typed_op!(); +} + +impl EvalOp for GpuPad { + fn is_stateless(&self) -> bool { + true + } + + fn eval_with_session( + &self, + node_id: usize, + session: &TurnState, + inputs: TVec, + ) -> TractResult> { + let input_value = args_1!(inputs); + let input = input_value.to_device_tensor()?; + let dt = input.datum_type(); + let out_shape = self.output_shape(input.shape()); + + let output = + crate::session_handler::make_tensor_for_node(session, node_id, dt, &out_shape)?; + + let ctx = crate::device::get_context()?; + + // Fill the whole output with the pad value, broadcast from a scalar. + let value = self.value.cast_to_dt(dt)?.into_owned().into_device()?; + let zero_strides = vec![0isize; out_shape.len()]; + ctx.copy_nd(&value, 0, &zero_strides, &output, 0, &out_shape, output.strides())?; + + // Place the input at the interior offset. 
+ if input.len() != 0 { + let interior: usize = self + .pads + .iter() + .enumerate() + .map(|(axis, (before, _))| before * output.strides()[axis] as usize) + .sum(); + ctx.copy_nd( + input, + 0, + input.strides(), + &output, + interior * dt.size_of(), + input.shape(), + output.strides(), + )?; + } + Ok(tvec![output.into_tensor().into_tvalue()]) + } +} + +impl TypedOp for GpuPad { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + crate::utils::facts_to_device_facts(inputs, |facts| { + Ok(tvec!(facts[0].datum_type.fact(self.output_shape(&facts[0].shape.to_tvec())))) + }) + .with_context(|| format!("Error while computing facts for {:?}", self.name())) + } + + as_op!(); +} diff --git a/gpu/src/ops/slice.rs b/gpu/src/ops/slice.rs index 925bc4bf42..aac38496bc 100644 --- a/gpu/src/ops/slice.rs +++ b/gpu/src/ops/slice.rs @@ -94,7 +94,7 @@ impl TypedOp for GpuSlice { .with_context(|| format!("Error while computing facts for {:?}", self.name())) } - fn substitute_symbols( + fn set_symbols( &self, _source: &TypedModel, node: &TypedNode, diff --git a/harness/core-proptest-pulse/src/fft.rs b/harness/core-proptest-pulse/src/fft.rs new file mode 100644 index 0000000000..fc78763b78 --- /dev/null +++ b/harness/core-proptest-pulse/src/fft.rs @@ -0,0 +1,179 @@ +use proptest::test_runner::TestCaseResult; +use tract_core::ops::fft::Fft; + +use super::*; + +/// FFT applied on a non-streaming axis must be pulsifiable: the batch +/// axes are 1-to-1 passthrough in `Fft::axes_mapping`, so the generic +/// per-pulse wrapper handles streaming any of them. Without that mapping +/// the pulse pass bails on the first FFT it sees, which blocks the entire +/// DPDFNet / DeepFilterNet streaming family (STFT lowers to STFT + per- +/// frame FFT, and the frame axis is the streaming one but the FFT axis +/// is the per-frame frequency axis: distinct, so streaming is sound). +/// +/// Setup: input is rank-4 (batch, T_stream, fft_size, 2). 
T_stream is +/// the streaming axis; FFT runs on `fft_size`; trailing 2 holds +/// (re, im). Pulse processes one chunk along T_stream at a time. +fn fft_on_non_streaming_axis(input_len: usize, pulse: usize, fft_size: usize) -> TestCaseResult { + fft_inner(input_len, pulse, fft_size, /* inverse = */ false, /* rank = */ 4) +} + +fn ifft_on_non_streaming_axis(input_len: usize, pulse: usize, fft_size: usize) -> TestCaseResult { + fft_inner(input_len, pulse, fft_size, /* inverse = */ true, /* rank = */ 4) +} + +/// Generic harness: build a model where the streaming axis is `1`, the +/// FFT axis is `rank - 2`, and the trailing axis (rank-1) is the +/// (re, im) pair. Rank 3 = (stream, fft, 2); rank 4 = (batch=1, stream, +/// fft, 2); rank 5 = (B=1, C=1, stream, fft, 2). All three are +/// rank-preserving so axes_mapping::natural lines them up identically. +fn fft_inner( + input_len: usize, + pulse: usize, + fft_size: usize, + inverse: bool, + rank: usize, +) -> TestCaseResult { + assert!(rank >= 3, "rank < 3 has no room for stream + fft + complex axes"); + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + // Build the symbolic input shape: leading 1s, then (S, fft_size, 2). + let mut sym_shape: Vec = (0..rank - 3).map(|_| 1.to_dim()).collect(); + sym_shape.push(s.clone().into()); + sym_shape.push(fft_size.to_dim()); + sym_shape.push(2.to_dim()); + let a = model.add_source("a", f32::fact(&*sym_shape)).unwrap(); + // FFT axis: the second-to-last (just before the (re, im) pair). + let fft_axis = rank - 2; + model.wire_node("fft", Fft { axis: fft_axis, inverse }, &[a]).unwrap(); + model.auto_outputs().unwrap(); + + // Concrete input: leading 1s, then (input_len, fft_size, 2). + let mut shape = vec![1usize; rank - 3]; + shape.push(input_len); + shape.push(fft_size); + shape.push(2); + let input: ArrayD = ArrayD::from_shape_fn(shape, |idx| { + // Some non-trivial pattern so the FFT outputs aren't all zeros. 
+ let stream_ix = idx[rank - 3]; + let freq_ix = idx[rank - 2]; + let cmplx_ix = idx[rank - 1]; + (stream_ix * fft_size * 2 + freq_ix * 2 + cmplx_ix) as f32 * 0.01 + }); + // Streaming axis = rank - 3 (the one we labelled S). + proptest_regular_against_pulse(model, pulse, input, rank - 3) +} + +#[test] +fn fft_pulse_smoke_8_pulse2_size4() { + fft_on_non_streaming_axis(8, 2, 4).unwrap(); +} + +#[test] +fn fft_pulse_smoke_4_pulse1_size8() { + fft_on_non_streaming_axis(4, 1, 8).unwrap(); +} + +#[test] +fn fft_pulse_smoke_6_pulse3_size4() { + fft_on_non_streaming_axis(6, 3, 4).unwrap(); +} + +/// Inverse FFT: same op with `inverse: true`. Verifies the natural +/// axes-mapping applies regardless of FFT direction. +#[test] +fn ifft_pulse_smoke_8_pulse2_size4() { + ifft_on_non_streaming_axis(8, 2, 4).unwrap(); +} + +/// Rank-3 input: just (stream, fft, 2), no batch / channel dims. +#[test] +fn fft_pulse_rank3() { + fft_inner(8, 2, 4, false, 3).unwrap(); +} + +/// Rank-5 input: two extra leading singleton dims before the stream +/// axis (the DPDFNet `df_op` shape is similar: `(B, C, T, F, 2)`). +#[test] +fn fft_pulse_rank5() { + fft_inner(8, 2, 4, false, 5).unwrap(); +} + +/// Stacked FFT then inverse FFT on the same axis: tract should track +/// the streaming axis through both ops via the natural mapping. 
+#[test] +fn fft_ifft_roundtrip_pulse() { + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + let a = model.add_source("a", f32::fact(dims!(1, s, 4, 2))).unwrap(); + let fwd = model.wire_node("fft", Fft { axis: 2, inverse: false }, &[a]).unwrap(); + model.wire_node("ifft", Fft { axis: 2, inverse: true }, &fwd).unwrap(); + model.auto_outputs().unwrap(); + + let input: ArrayD = + ArrayD::from_shape_fn(vec![1, 8, 4, 2], |idx| (idx[1] * 8 + idx[2] * 2 + idx[3]) as f32); + proptest_regular_against_pulse(model, 2, input, 1).unwrap(); +} + +/// Streaming on the FFT axis itself must be rejected: per-pulse FFT on +/// the FFT axis is meaningless (FFT needs every sample on that axis at +/// once). `Fft::axes_mapping` declares the FFT axis as input-only (it +/// does not map to any output axis), so the generic pulse fallback can +/// not track the streaming axis through it and bails. No dedicated +/// pulsifier is involved. +#[test] +fn fft_pulse_on_fft_axis_errors() { + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + // Streaming on axis 1, FFT also on axis 1 -> nonsense. + let a = model.add_source("a", f32::fact(dims!(1, s, 2))).unwrap(); + model.wire_node("fft", Fft { axis: 1, inverse: false }, &[a]).unwrap(); + model.auto_outputs().unwrap(); + let err = PulsedModel::new(&model, s.clone(), &2.to_dim()).unwrap_err(); + let msg = format!("{err:#}"); + assert!( + msg.contains("could not track pulsing axis"), + "expected the generic pulse fallback to refuse tracking the FFT \ + axis, got: {msg}" + ); +} + +/// Streaming on the trailing (re, im) axis is structurally +/// impossible: `Fft::output_facts` already rejects any input whose +/// trailing axis isn't `2`, so a symbolic trailing dim trips the +/// typed-model build long before the pulsifier runs. Lock in that +/// "earlier layer rejects this" contract. 
+#[test] +fn fft_pulse_on_complex_axis_errors() { + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + // Streaming on the last axis (the complex pair). + let a = model.add_source("a", f32::fact(dims!(1, 4, s))).unwrap(); + let err = model.wire_node("fft", Fft { axis: 1, inverse: false }, &[a]).unwrap_err(); + let msg = format!("{err:#}"); + assert!( + msg.contains("inner (last) dimension to be 2"), + "expected the typed-model rejection of a symbolic trailing \ + axis, got: {msg}" + ); +} + +proptest! { + #[test] + fn proptest_fft_pulse( + input_len in 1usize..16, + pulse in 1usize..4, + fft_size in proptest::sample::select(vec![2usize, 4, 8, 16]), + ) { + fft_on_non_streaming_axis(input_len, pulse, fft_size)? + } + + #[test] + fn proptest_ifft_pulse( + input_len in 1usize..16, + pulse in 1usize..4, + fft_size in proptest::sample::select(vec![2usize, 4, 8, 16]), + ) { + ifft_on_non_streaming_axis(input_len, pulse, fft_size)? + } +} diff --git a/harness/core-proptest-pulse/src/lib.rs b/harness/core-proptest-pulse/src/lib.rs index faaccb07ae..e89128bfd0 100644 --- a/harness/core-proptest-pulse/src/lib.rs +++ b/harness/core-proptest-pulse/src/lib.rs @@ -25,7 +25,9 @@ mod deconv; mod delay_plus_downsample; mod delay_plus_pool; mod einsum; +mod fft; mod pad_plus_conv; +mod stft; #[allow(dead_code)] fn setup_test_logger() { @@ -47,7 +49,7 @@ fn proptest_regular_against_pulse( let subs = std::collections::HashMap::from([(s.clone(), tract_data::prelude::TDim::Val(len as i64))]); - let concrete = model.clone().substitute_symbols(&subs).unwrap(); + let concrete = model.clone().set_symbols(&subs).unwrap(); if concrete.nodes.iter().any(|n| n.outputs.iter().any(|o| o.fact.shape.volume().is_zero())) { return Err(TestCaseError::reject("too short input")); } diff --git a/harness/core-proptest-pulse/src/stft.rs b/harness/core-proptest-pulse/src/stft.rs new file mode 100644 index 0000000000..f3ed69f576 --- /dev/null +++ 
b/harness/core-proptest-pulse/src/stft.rs @@ -0,0 +1,65 @@ +use proptest::test_runner::TestCaseResult; +use tract_core::ops::fft::Stft; + +use super::*; + +/// STFT applied with the streaming axis distinct from the STFT axis +/// must be pulsifiable: every non-STFT axis is a 1-to-1 passthrough +/// once `Stft::axes_mapping` declares the relationship (input axis +/// `op.axis` maps to output `op.axis` as `n_frames`; output `op.axis + +/// 1` is the inserted frame axis; the rest shift naturally). Without +/// the mapping the pulse pass bails with "could not track pulsing +/// axis" the moment a batched STFT pipeline streams its batch axis. +/// +/// Setup: input is rank-3 `(B_stream, T, 2)`. B_stream is the +/// streaming axis (axis 0); STFT runs on the T axis (axis 1); the +/// trailing 2 holds (re, im). One pulse = one batch element; tract +/// runs the full-length STFT inside each pulse. +fn stft_on_non_stft_axis( + batch_len: usize, + pulse: usize, + time_len: usize, + frame: usize, + stride: usize, +) -> TestCaseResult { + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + let a = model.add_source("a", f32::fact(dims!(s, time_len, 2))).unwrap(); + model.wire_node("stft", Stft { axis: 1, frame, stride, window: None }, &[a]).unwrap(); + model.auto_outputs().unwrap(); + + let input: ArrayD = ArrayD::from_shape_fn(vec![batch_len, time_len, 2], |idx| { + (idx[0] * time_len * 2 + idx[1] * 2 + idx[2]) as f32 * 0.01 + }); + proptest_regular_against_pulse(model, pulse, input, 0) +} + +#[test] +fn stft_pulse_batch_axis_smoke_4_pulse2_t8_frame4_stride2() { + stft_on_non_stft_axis(4, 2, 8, 4, 2).unwrap(); +} + +#[test] +fn stft_pulse_batch_axis_smoke_3_pulse1_t6_frame3_stride1() { + stft_on_non_stft_axis(3, 1, 6, 3, 1).unwrap(); +} + +#[test] +fn stft_pulse_batch_axis_smoke_2_pulse2_t12_frame4_stride4() { + stft_on_non_stft_axis(2, 2, 12, 4, 4).unwrap(); +} + +proptest! 
{ + #[test] + fn proptest_stft_pulse_batch_axis( + batch_len in 1usize..6, + pulse in 1usize..3, + time_len in 4usize..16, + frame in proptest::sample::select(vec![2usize, 4]), + stride in proptest::sample::select(vec![1usize, 2]), + ) { + // Skip frame > time_len -- the STFT would produce 0 frames. + prop_assume!(time_len >= frame); + stft_on_non_stft_axis(batch_len, pulse, time_len, frame, stride)? + } +} diff --git a/harness/nemotron-speech-streaming-en-0.6b/ci.sh b/harness/nemotron-speech-streaming-en-0.6b/ci.sh index 5b3697e60b..35e30c0272 100755 --- a/harness/nemotron-speech-streaming-en-0.6b/ci.sh +++ b/harness/nemotron-speech-streaming-en-0.6b/ci.sh @@ -25,11 +25,13 @@ do nnef_file=$MODEL.$m.nnef.tgz fi # Decoder is stepped one token per call by the caller (external state - # carry); force the Scan op into single-iter inlining so the LSTM body - # lands on the GPU instead of bouncing through CPU each step. + # carry): assert the external_state flag and concretize the seq symbol + # to 1 so the Scan inlines and the LSTM body lands on the GPU instead of + # bouncing through CPU each step. set_symbols RON must stay space-free + # ($extra_transforms is passed unquoted). 
extra_transforms="" if [ "$m" = "decoder" ]; then - extra_transforms="-t force_scan_external_state" + extra_transforms='-t force_scan_external_state -t set_symbols(values:{"TARGETS__TIME":1})' fi $CACHE_FILE \ $S3DIR/$nnef_file \ @@ -46,16 +48,18 @@ model_prefix=$MODELS/$S3DIR/$MODEL # Check that the patch transform eliminates all Iff nodes, # and that select_outputs can reduce the model to a single output $TRACT_RUN $model_prefix.preprocessor.nnef.tgz \ - -t 'concretize_symbols(values: {"BATCH": 1})' \ + -t 'set_symbols(values: {"BATCH": 1})' \ -t 'patch(body: "length = tract_core_shape_of(input_signal)[1];")' \ + -t 'select_inputs(inputs: ["input_signal"])' \ -t 'select_outputs(outputs: ["processed_signal"])' \ dump -q \ --assert-op-count Iff 0 # Check that the preprocessor can be pulsified $TRACT_RUN $model_prefix.preprocessor.nnef.tgz \ - -t 'concretize_symbols(values: {"BATCH": 1})' \ + -t 'set_symbols(values: {"BATCH": 1})' \ -t 'patch(body: "length = tract_core_shape_of(input_signal)[1];")' \ + -t 'select_inputs(inputs: ["input_signal"])' \ -t 'select_outputs(outputs: ["processed_signal"])' \ -t 'pulse(symbol: Some("INPUT_SIGNAL__TIME"), pulse: "4800")' \ dump -q @@ -79,14 +83,14 @@ do *) continue;; esac $TRACT_RUN $model_prefix.preprocessor.nnef.tgz $rt \ - -t 'concretize_symbols(values: {"BATCH": 1})' \ + -t 'set_symbols(values: {"BATCH": 1})' \ -t 'patch(body: "length = tract_core_shape_of(input_signal)[1];")' \ -t 'select_outputs(outputs: ["processed_signal"])' \ -t 'pulse(symbol: Some("INPUT_SIGNAL__TIME"), pulse: "4800")' \ dump -q $pp_assert $TRACT_RUN $model_prefix.encoder.p1.nnef.tgz $rt \ --nnef-tract-transformers \ - -t 'concretize_symbols(values: {"BATCH": 1})' \ + -t 'set_symbols(values: {"BATCH": 1})' \ -t 'patch(body: "length = tract_core_shape_of(audio_signal)[2];")' \ -t 'select_outputs(outputs: ["outputs"])' \ -t 'pulse(symbol: Some("AUDIO_SIGNAL__TIME"), pulse: "112")' \ @@ -99,8 +103,9 @@ done # must be 14 * 8 = 112 audio frames. 
$TRACT_RUN $model_prefix.encoder.p1.nnef.tgz \ --nnef-tract-transformers \ - -t 'concretize_symbols(values: {"BATCH": 1})' \ + -t 'set_symbols(values: {"BATCH": 1})' \ -t 'patch(body: "length = tract_core_shape_of(audio_signal)[2];")' \ + -t 'select_inputs(inputs: ["audio_signal"])' \ -t 'select_outputs(outputs: ["outputs"])' \ -t 'pulse(symbol: Some("AUDIO_SIGNAL__TIME"), pulse: "112")' \ dump -q @@ -110,8 +115,9 @@ $TRACT_RUN $model_prefix.encoder.p1.nnef.tgz \ # and the output comparison is trimmed accordingly. $TRACT_RUN $model_prefix.encoder.p1.nnef.tgz \ --nnef-tract-transformers \ - -t 'concretize_symbols(values: {"BATCH": 1})' \ + -t 'set_symbols(values: {"BATCH": 1})' \ -t 'patch(body: "length = tract_core_shape_of(audio_signal)[2];")' \ + -t 'select_inputs(inputs: ["audio_signal"])' \ -t 'select_outputs(outputs: ["outputs"])' \ -t 'pulse(symbol: Some("AUDIO_SIGNAL__TIME"), pulse: "112")' \ run \ diff --git a/harness/nnef-test-cases/conv-then-shape-of-mask/runme.sh b/harness/nnef-test-cases/conv-then-shape-of-mask/runme.sh index 2faa03e324..07af5af9b7 100755 --- a/harness/nnef-test-cases/conv-then-shape-of-mask/runme.sh +++ b/harness/nnef-test-cases/conv-then-shape-of-mask/runme.sh @@ -7,7 +7,7 @@ set -ex # Batch mode: S=8 -> conv output T = 1 + 8/2 = 5 frames; the add is fine. $TRACT_RUN --nnef-tract-core . \ - -t 'concretize_symbols(values: {"S": 8})' \ + -t 'set_symbols(values: {"S": 8})' \ run --allow-random-input -q # Streaming compare: pulse=4 -> conv produces 2 frames/step. The diff --git a/harness/nnef-test-cases/scan-body-stream-drift/graph.nnef b/harness/nnef-test-cases/scan-body-stream-drift/graph.nnef new file mode 100644 index 0000000000..6e9ea2c484 --- /dev/null +++ b/harness/nnef-test-cases/scan-body-stream-drift/graph.nnef @@ -0,0 +1,35 @@ +version 1.0; + +extension tract_registry tract_core; +extension tract_symbol S; + +# Path 1 construction (scan axis == streaming axis), with the streaming +# symbol on a Full body source. 
The Scan iterates the streaming input on +# axis 0 (chunk 1). A second Full body slot receives an internally +# constructed zero-valued tensor whose shape mirrors the streaming +# input ([S, 4]); the Pulsifier walks both wires and substitutes +# `S -> pulse` in the outer scan op. The fix in pulse/ops/scan.rs +# substitutes the same symbol inside the body so the Full slot's body +# source goes from [S, 4] to [pulse, 4]; without the fix the body's +# source fact stays symbolic in S. + +fragment scan_body_fn( + scan_x: tensor, + full_y: tensor +) -> ( y_out: tensor ) { + full_first = slice(full_y, axes = [0], begin = [0], end = [1]); + y_out = add(scan_x, full_first); +} + +graph scan_body_stream_drift(input) -> (output) +{ + input = external(shape = [S, 4]); + full_zeros = mul(input, 0.0); + output = tract_core_scan( + body = "scan_body_fn", + scan = [("scan_x", input, 0, 1)], + full = [("full_y", full_zeros)], + state = [], + output = [("y_out", "full", 0, 1)] + ); +} diff --git a/harness/nnef-test-cases/scan-body-stream-drift/runme.sh b/harness/nnef-test-cases/scan-body-stream-drift/runme.sh new file mode 100755 index 0000000000..3a58360eb3 --- /dev/null +++ b/harness/nnef-test-cases/scan-body-stream-drift/runme.sh @@ -0,0 +1,29 @@ +#!/bin/sh + +cd `dirname $0` +set -ex + +: ${TRACT_RUN:=cargo run -p tract-cli $CARGO_OPTS --} + +# Batch run and streaming compare: smoke test that pulsified model +# loads and produces the same output as the batched one. +$TRACT_RUN --nnef-tract-core . \ + -t 'set_symbols(values: {"S": 8})' \ + run --allow-random-input -q + +$TRACT_RUN --nnef-tract-core . --pulse 4 compare \ + --stream --allow-random-input -q + +# Body / outer fact consistency: the Pulsifier substitutes the stream +# symbol S in the outer wire facts; it must do the same on the Scan +# body source facts. 
Without the fix in pulse/src/ops/scan.rs the Full +# slot's body source keeps `S,4,F32` while the outer wire is +# `4,4,F32 [pulse axis:0 ...]`; assert here that the body source is +# a concrete shape (no residual stream symbol). +DUMP=$($TRACT_RUN --nnef-tract-core . --pulse 4 --pass pulse dump 2>&1 | sed 's/\x1b\[[0-9;]*m//g') +if echo "$DUMP" | grep -A 1 '\[loop\].*Source.*full_y' | grep -qE '\bS\b'; then + echo "ERROR: Scan body source for 'full_y' still carries the stream symbol after pulse pass." + echo " Pulsifier did not substitute it in the Scan body." + echo "$DUMP" | grep -A 1 '\[loop\].*Source.*full_y' + exit 1 +fi diff --git a/harness/nnef-test-cases/slice-of-static-with-streaming-size/runme.sh b/harness/nnef-test-cases/slice-of-static-with-streaming-size/runme.sh index 526d204fcb..cca6b81524 100755 --- a/harness/nnef-test-cases/slice-of-static-with-streaming-size/runme.sh +++ b/harness/nnef-test-cases/slice-of-static-with-streaming-size/runme.sh @@ -7,7 +7,7 @@ set -ex # Batch mode: concretize S=8 -> pe_table[0:8, :] + input[0:8, :] $TRACT_RUN --nnef-tract-core . \ - -t 'concretize_symbols(values: {"S": 8})' \ + -t 'set_symbols(values: {"S": 8})' \ run --allow-random-input -q # Streaming compare: pulse=4. Each step slices pe_table[0:4, :] (constant diff --git a/harness/parakeet-tdt-600m-v3/ci.sh b/harness/parakeet-tdt-600m-v3/ci.sh index b891a0dbbb..6293bacddf 100755 --- a/harness/parakeet-tdt-600m-v3/ci.sh +++ b/harness/parakeet-tdt-600m-v3/ci.sh @@ -22,13 +22,15 @@ do nnef_file=nvidia--parakeet-tdt-0.6b-v3-f32f32.$m.nnef.tgz fi # decoder LSTM is externally state-managed (caller plumbs state_0/ - # state_1 every run): force the external_state flag on Scans - # pre-optimisation so declutter_single_loop inlines them, and assert - # no Scan survives. Cached NNEF predates the flag — see #2157. 
+ # state_1 every run) and stepped one token per call: assert the + # external_state flag on Scans (cached NNEF predates it — see #2157) + # and concretize the seq symbol to 1, so declutter_single_loop inlines + # them. Assert no Scan survives. set_symbols RON must stay space-free + # ($extra_transform is passed unquoted). extra_transform="" extra_assert="" if [ "$m" = "decoder" ]; then - extra_transform="-t force_scan_external_state" + extra_transform='-t force_scan_external_state -t set_symbols(values:{"T":1})' extra_assert="--assert-op-count Scan 0" fi $CACHE_FILE \ diff --git a/harness/tf-inceptionv3/Cargo.toml b/harness/tf-inceptionv3/Cargo.toml index 4d5a56dd56..be5be5472e 100644 --- a/harness/tf-inceptionv3/Cargo.toml +++ b/harness/tf-inceptionv3/Cargo.toml @@ -9,9 +9,6 @@ edition = "2024" image.workspace = true tract-tensorflow.workspace = true -[features] -conform = [ "tract-tensorflow/conform" ] - [dev-dependencies] criterion.workspace = true dinghy-test.workspace = true diff --git a/harness/tf-inceptionv3/benches/inceptionv3.rs b/harness/tf-inceptionv3/benches/inceptionv3.rs index 649bef25fd..8fc988fd70 100644 --- a/harness/tf-inceptionv3/benches/inceptionv3.rs +++ b/harness/tf-inceptionv3/benches/inceptionv3.rs @@ -16,25 +16,6 @@ pub fn hopper() -> path::PathBuf { test_project_path().join(HOPPER) } -#[cfg(feature = "conform")] -fn dummy(_bencher: &mut Criterion) { - tract_tensorflow::conform::tf::for_path(tf_inceptionv3::inception_v3_2016_08_28_frozen()) - .unwrap(); -} - -#[cfg(feature = "conform")] -fn tf(bencher: &mut Criterion) { - let mut tf = - tract_tensorflow::conform::tf::for_path(tf_inceptionv3::inception_v3_2016_08_28_frozen()) - .unwrap(); - let input = tf_inceptionv3::load_image(hopper()); - bencher.bench_function("tensorflow", move |b| { - b.iter(|| { - tf.run(vec![("input", input.clone())], "InceptionV3/Predictions/Reshape_1").unwrap() - }) - }); -} - fn tract(bencher: &mut Criterion) { let mut tfd = 
tensorflow().model_for_path(tf_inceptionv3::inception_v3_2016_08_28_frozen()).unwrap(); @@ -46,11 +27,6 @@ fn tract(bencher: &mut Criterion) { pub fn benches() { let mut criterion: Criterion = Criterion::default().sample_size(3).configure_from_args(); - #[cfg(feature = "conform")] - { - dummy(&mut criterion); - tf(&mut criterion); - } tract(&mut criterion); } criterion_main!(benches); diff --git a/hir/Cargo.toml b/hir/Cargo.toml index 4dde2037a3..0274739ab0 100644 --- a/hir/Cargo.toml +++ b/hir/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-hir" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol "] description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/libcli/Cargo.toml b/libcli/Cargo.toml index 157002f031..a3ea505f63 100644 --- a/libcli/Cargo.toml +++ b/libcli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-libcli" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol "] description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" @@ -35,8 +35,8 @@ tract-transformers = { workspace = true, optional = true } tract-metal = { workspace = true } [target.'cfg(any(target_os = "linux", target_os = "windows"))'.dependencies] -cudarc.workspace = true -tract-cuda.workspace = true +cudarc = { workspace = true, optional = true } +tract-cuda = { workspace = true, optional = true } [features] default = ["transformers"] @@ -45,3 +45,8 @@ hir = [] onnx = ["tract-onnx"] complex = ["tract-core/complex"] transformers = ["tract-transformers"] +# Marker / dependency selector for cudarc-backed code paths +# (`--cuda-gpu-trace`). Enabled transitively by tract-cli's cuda-XXXXX +# features. Always pulled together — picking one selects the cudarc API +# binding; this feature wires the optional dep in. 
+cuda = ["dep:cudarc", "dep:tract-cuda"] diff --git a/libcli/src/lib.rs b/libcli/src/lib.rs index 5527554d58..3246b5d4a3 100644 --- a/libcli/src/lib.rs +++ b/libcli/src/lib.rs @@ -14,7 +14,7 @@ pub mod time; use tract_core::internal::*; #[allow(unused_imports)] -#[cfg(any(target_os = "linux", target_os = "windows"))] +#[cfg(all(any(target_os = "linux", target_os = "windows"), feature = "cuda"))] use tract_cuda::utils::ensure_cuda_runtime_dependencies; pub fn capture_gpu_trace(matches: &clap::ArgMatches, func: F) -> TractResult<()> @@ -45,7 +45,7 @@ where bail!("`--metal-gpu-trace` present but it is only available on MacOS and iOS") } } else if matches.get_flag("cuda-gpu-trace") { - #[cfg(any(target_os = "linux", target_os = "windows"))] + #[cfg(all(any(target_os = "linux", target_os = "windows"), feature = "cuda"))] { ensure_cuda_runtime_dependencies( "`--cuda-gpu-trace` present but no CUDA installation has been found", @@ -53,9 +53,12 @@ where let _prof = cudarc::driver::safe::Profiler::new()?; func() } - #[cfg(not(any(target_os = "linux", target_os = "windows")))] + #[cfg(not(all(any(target_os = "linux", target_os = "windows"), feature = "cuda")))] { - bail!("`--cuda-gpu-trace` present but it is only available on Linux and Windows") + bail!( + "`--cuda-gpu-trace` present but tract was not built with CUDA support \ + (re-build on Linux/Windows with one of the cuda-XXXXX features)" + ) } } else { func() diff --git a/linalg/Cargo.toml b/linalg/Cargo.toml index 68955ac179..d29cfa27ae 100644 --- a/linalg/Cargo.toml +++ b/linalg/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-linalg" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol "] description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" @@ -83,6 +83,10 @@ harness = false name = "mm_for_asr_am" harness = false +[[bench]] +name = "qmmm_i8" +harness = false + [[bench]] name = "hardswish" harness = false @@ -103,6 +107,22 @@ harness = 
false name = "softmax" harness = false +[[bench]] +name = "activations_avx512_f16" +harness = false + +[[bench]] +name = "erf" +harness = false + +[[bench]] +name = "rms_norm" +harness = false + +[[bench]] +name = "activations_avx512_fp16" +harness = false + [[bench]] bench = false name = "arm64simd" @@ -136,6 +156,14 @@ harness = false name = "avx512_zombies" harness = false +[[bench]] +name = "activations_avx512" +harness = false + [[bench]] name = "wasm" harness = false + +[[bench]] +name = "vnni_i32" +harness = false diff --git a/linalg/arm64/arm64simd/arm64simd_mmm_i32_8x8_dot.S.j2 b/linalg/arm64/arm64simd/arm64simd_mmm_i32_8x8_dot.S.j2 new file mode 100644 index 0000000000..0a63a74edf --- /dev/null +++ b/linalg/arm64/arm64simd/arm64simd_mmm_i32_8x8_dot.S.j2 @@ -0,0 +1,235 @@ +// vim: ft=arm + +// C tile regs: +// - x19-x29 to preserve (but x19, x28, x29 not used) +// - d8..d15 to preserve +// - v16 to v31, no need to preserve +// +// v16[0] v18[0] v20[0] v22[0] v24[0] v26[0] v28[0] v30[0] +// v16[1] v18[1] +// v16[2] v18[2] +// v16[3] v18[3] +// +// v17[0] v19[0] v21[0] v23[0] v25[0] v27[0] v29[0] v31[0] +// v17[1] v19[1] +// v17[2] v19[2] +// v17[3] v19[3] + +// no preservation either for v0-v7... +// packed A buffering (2x8 values): alternating v0, v1 with v2, v3 +// packed B buffering (2x8 values): alternating v4, v5 with v6, v7 + +.text +.align 4 + +.cpu generic+fp+simd+dotprod +.global {{G}}arm64simd_mmm_i32_8x8_dot_{{suffix}} +{{G}}arm64simd_mmm_i32_8x8_dot_{{suffix}}: + +/* + prfm pldl1keep, [x1] + prfm pldl1keep, [x2] +*/ + stp x20, x21, [sp, #-16]! + stp x22, x23, [sp, #-16]! + stp x24, x25, [sp, #-16]! + stp x26, x27, [sp, #-16]! + + stp d8, d9, [sp, #-16]! + stp d10, d11, [sp, #-16]! + stp d12, d13, [sp, #-16]! + stp d14, d15, [sp, #-16]! 
+ +{% include "dispatcher.j2" %} + +.add_mat_mul: + ldp x2, x4, [x0, #24] // b, packing + ldp x3, x1, [x0, #8] // k, a + + cmp x3, #0 + beq .non_linear_loop + + cmp x4, #1 + beq .packed_packed_loop_1_i8i8 + +.packed_packed_loop_1: + + ld1 { v0.4s, v1.4s }, [ x1 ], #32 + ld1 { v4.4s, v5.4s }, [ x2 ], #32 + + mla v16.4s, v0.4s, v4.s[0] + mla v17.4s, v1.4s, v4.s[0] + mla v18.4s, v0.4s, v4.s[1] + mla v19.4s, v1.4s, v4.s[1] + + mla v20.4s, v0.4s, v4.s[2] + mla v21.4s, v1.4s, v4.s[2] + mla v22.4s, v0.4s, v4.s[3] + mla v23.4s, v1.4s, v4.s[3] + + mla v24.4s, v0.4s, v5.s[0] + mla v25.4s, v1.4s, v5.s[0] + mla v26.4s, v0.4s, v5.s[1] + mla v27.4s, v1.4s, v5.s[1] + + mla v28.4s, v0.4s, v5.s[2] + mla v29.4s, v1.4s, v5.s[2] + mla v30.4s, v0.4s, v5.s[3] + mla v31.4s, v1.4s, v5.s[3] + + subs x3, x3, #1 + bne .packed_packed_loop_1 + + b .non_linear_loop + +.packed_packed_loop_1_i8i8: + // PackedI8K4 (K=4-inner, r=8): per 4-K block, A is m0-3 (v0) / m4-7 (v1), + // B is n0-3 (v4) / n4-7 (v5), each lane a 4xi8 group. SDOT by-element dots + // a B column's 4 K against all 4 m rows of an A half. Same v16..v31 tile + // layout as the SMLAL kernel: v[16 + n*2 + m_half] = C[m_half*4..][n]. 
+ ld1 { v0.16b, v1.16b }, [ x1 ], #32 + ld1 { v4.16b, v5.16b }, [ x2 ], #32 + + sdot v16.4s, v0.16b, v4.4b[0] + sdot v17.4s, v1.16b, v4.4b[0] + sdot v18.4s, v0.16b, v4.4b[1] + sdot v19.4s, v1.16b, v4.4b[1] + sdot v20.4s, v0.16b, v4.4b[2] + sdot v21.4s, v1.16b, v4.4b[2] + sdot v22.4s, v0.16b, v4.4b[3] + sdot v23.4s, v1.16b, v4.4b[3] + + sdot v24.4s, v0.16b, v5.4b[0] + sdot v25.4s, v1.16b, v5.4b[0] + sdot v26.4s, v0.16b, v5.4b[1] + sdot v27.4s, v1.16b, v5.4b[1] + sdot v28.4s, v0.16b, v5.4b[2] + sdot v29.4s, v1.16b, v5.4b[2] + sdot v30.4s, v0.16b, v5.4b[3] + sdot v31.4s, v1.16b, v5.4b[3] + + subs x3, x3, #4 + bgt .packed_packed_loop_1_i8i8 + + b .non_linear_loop + +{% set from = 16 %}{% set to = 31 %}{% include "arm64simd_mmm_i32_scalars.j2" %} +{% set mr = 8 %}{% set from = 16 %}{% set to = 31 %}{% include "arm64simd_mmm_i32_per_rows.j2" %} +{% set mr = 8 %}{% set from = 16 %}{% set to = 31 %}{% include "arm64simd_mmm_i32_per_cols.j2" %} +{% set from = 16 %}{% set to = 31 %}{% include "arm64simd_mmm_load_tile.j2" %} + +.add_unicast: + ldp x5, x6, [x0, #8] + ldp x7, x8, [x0, #24] + + cmp x8, #4 + beq non_linear_addc_i32 + + {% for col in range(8, 16) %} + mov x4, x5 + {% for reg in range(0, 2) %} + {% for lane in range(0, 4) %} + ld1 {v0.b}[{{lane}}], [ x4 ], x6 + {% endfor %} + sshll v0.8h, v0.8b, 0 + sshll v0.4s, v0.4h, 0 + add v{{ col * 2 + reg }}.4s, v{{ col * 2 + reg }}.4s, v0.4s + {% endfor %} + add x5, x5, x7 + {% endfor %} + + b .non_linear_loop + +non_linear_addc_i32: + {% for col in range(8, 16) %} + mov x4, x5 + {% for reg in range(0, 2) %} + {% for lane in range(0, 4) %} + ld1 {v0.s}[{{lane}}], [ x4 ], x6 + {% endfor %} + add v{{ col * 2 + reg }}.4s, v{{ col * 2 + reg }}.4s, v0.4s + {% endfor %} + add x5, x5, x7 + {% endfor %} + + b .non_linear_loop + +.add_row_col_products: + ldr x2, [x0, #8] + ldr x3, [x0, #16] + + ld1 { v0.4s, v1.4s }, [ x2 ] + ld1 { v4.4s, v5.4s }, [ x3 ] + + xtn v0.4h, v0.4s + xtn v1.4h, v1.4s + xtn v4.4h, v4.4s + xtn v5.4h, v5.4s + + 
smlal v16.4s, v0.4h, v4.h[0] + smlal v17.4s, v1.4h, v4.h[0] + smlal v18.4s, v0.4h, v4.h[1] + smlal v19.4s, v1.4h, v4.h[1] + smlal v20.4s, v0.4h, v4.h[2] + smlal v21.4s, v1.4h, v4.h[2] + smlal v22.4s, v0.4h, v4.h[3] + smlal v23.4s, v1.4h, v4.h[3] + + smlal v24.4s, v0.4h, v5.h[0] + smlal v25.4s, v1.4h, v5.h[0] + smlal v26.4s, v0.4h, v5.h[1] + smlal v27.4s, v1.4h, v5.h[1] + smlal v28.4s, v0.4h, v5.h[2] + smlal v29.4s, v1.4h, v5.h[2] + smlal v30.4s, v0.4h, v5.h[3] + smlal v31.4s, v1.4h, v5.h[3] + + b .non_linear_loop + + {% include "arm64simd_mmm_i32_scale_q16_q31.j2" %} + +.store: + ldp x5, x6, [x0, #8] // c base ptr, rsc + ldp x7, x8, [x0, #24] // csc, item_size + + cmp x8, #4 + beq .store_strides_i32 + + {% for col in range(8, 16) %} + mov x4, x5 + {% for reg in range(0, 2) %} + {% for lane in range(0, 4) %} + st1 { v{{ col * 2 + reg }}.b }[{{ lane * 4 }}], [ x4 ], x6 + {% endfor %} + {% endfor %} + add x5, x5, x7 + {% endfor %} + + b .non_linear_loop + +.store_strides_i32: + {% for col in range(8, 16) %} + mov x4, x5 + {% for reg in range(0, 2) %} + {% for lane in range(0, 4) %} + st1 { v{{ col * 2 + reg }}.s }[{{lane}}], [ x4 ], x6 + {% endfor %} + {% endfor %} + add x5, x5, x7 + {% endfor %} + + b .non_linear_loop + +.return: + ldp d14, d15, [sp], #16 + ldp d12, d13, [sp], #16 + ldp d10, d11, [sp], #16 + ldp d8, d9, [sp], #16 + + ldp x26, x27, [sp], #16 + ldp x24, x25, [sp], #16 + ldp x22, x23, [sp], #16 + ldp x20, x21, [sp], #16 + + ret + diff --git a/linalg/arm64/arm64simd/dummy_dotprod.S b/linalg/arm64/arm64simd/dummy_dotprod.S new file mode 100644 index 0000000000..4304549674 --- /dev/null +++ b/linalg/arm64/arm64simd/dummy_dotprod.S @@ -0,0 +1,13 @@ +// Build-time capability probe for the assembler, used by build.rs +// (assembler_supports_dotprod). Older binutils — notably the Debian stretch +// aarch64 cross-toolchain in CI — predate FEAT_DotProd and cannot assemble +// `sdot` even with `.cpu generic+fp+simd+dotprod`. 
If this file fails to +// assemble, build.rs skips the SDOT kernel and the `tract_arm64_dotprod` cfg, +// and the runtime falls back to the SMLAL 8x8 i32 kernel. Not linked into +// anything. +.cpu generic+fp+simd+dotprod +.text +.globl tract_dotprod_probe +tract_dotprod_probe: + sdot v0.4s, v1.16b, v2.4b[0] + ret diff --git a/linalg/arm64/sme/sme_qmmm_i32_32x32.S.j2 b/linalg/arm64/sme/sme_qmmm_i32_32x32.S.j2 new file mode 100644 index 0000000000..7c7e2f5b75 --- /dev/null +++ b/linalg/arm64/sme/sme_qmmm_i32_32x32.S.j2 @@ -0,0 +1,681 @@ +// vim: ft=arm +// +// SME2 i32 32x32 quantized matmul kernel. +// +// ZA tile layout (4 .S tiles, 16x16 i32 each): +// ZA0.S : C[0..16, 0..16] (top-left) +// ZA1.S : C[0..16, 16..32] (top-right) +// ZA2.S : C[16..32, 0..16] (bottom-left) +// ZA3.S : C[16..32, 16..32] (bottom-right) +// +// Inner K-step (K decrements by 4 per iter, since SMOPA at i8 reduces 4): +// ld1b {z0, z1}, pn8/z, [A] ; 32 M × 4 K = 128 i8 of A +// ld1b {z2, z3}, pn8/z, [B] ; 32 N × 4 K = 128 i8 of B +// smopa za0.s, p0/m, p0/m, z0.b, z2.b ; ZA0 += A[0..16] × B[0..16] +// smopa za1.s, p0/m, p0/m, z0.b, z3.b +// smopa za2.s, p0/m, p0/m, z1.b, z2.b +// smopa za3.s, p0/m, p0/m, z1.b, z3.b +// +// SMOPA at i8 throughput: 4-way K reduction per insn × 16x16 cells = 1024 +// MACs per insn. With 4-tile rotation we approach 4 SMOPAs/cycle = 4 K +// reduction × 16x16 = 4096 MACs/cycle ≈ ~16 TOPS theoretical peak. +// +// Calling convention (extern "C", AAPCS64): +// x0 = const *FusedKerSpec, advanced 40 B per dispatcher iteration +// x1 = 4 KiB scratch buffer for tile spills (used by store-generic / q_scale) +// +// Tract packing requirement: i8 inputs packed with K_alignment=4 (SMOPA +// requires K%4=0). The PackedFormat::with_k_alignment(4) handles this. + +.arch armv9-a+sme2 +.text +.align 4 + +.global {{G}}sme_qmmm_i32_32x32_{{suffix}} +{{G}}sme_qmmm_i32_32x32_{{suffix}}: + + stp q8, q9, [sp, #-128]! 
+ stp q10, q11, [sp, #32] + stp q12, q13, [sp, #64] + stp q14, q15, [sp, #96] + + sub sp, sp, #4096 + mov x1, sp + + smstart + ptrue p0.b + ptrue pn8.b + mov w8, #0 + +{% include "dispatcher.j2" %} + +// -------- AddMatMul: ZA += A·B at i8 with K=4 reduction per SMOPA ---------- + +.add_mat_mul: + ldr x9, [x0, #32] // packing index + ldr x2, [x0, #24] // b ptr + ldp x3, x4, [x0, #8] // k, a ptr + cmp x3, #0 + b.eq .non_linear_loop + cmp x9, #1 + b.eq .Lmatmul_loop +// i32i32 fallback (packing != 1, auto-test path): ZA += A[:,k] (x) B[k,:], one +// K-step at a time via predicated MLA rank-1 updates. One instruction per line: +// the Apple/LLVM AArch64 assembler treats `;` as a COMMENT, so semicolon-packed +// statements silently drop everything after the first `;`. +.Lk32: + ld1w {z2.s}, p0/z, [x2] // B[k, 0..16] + ld1w {z3.s}, p0/z, [x2, #1, mul vl] // B[k, 16..32] + mov w12, #0 +.Lkt: + ldr w10, [x4, w12, uxtw #2] // A[k, w12] + dup z4.s, w10 + mov z16.s, p0/m, za0h.s[w12, 0] + mov z17.s, p0/m, za1h.s[w12, 0] + mla z16.s, p0/m, z2.s, z4.s // C[w12, 0..16] += A[w12] * B[0..16] + mla z17.s, p0/m, z3.s, z4.s // C[w12, 16..32] += A[w12] * B[16..32] + mov za0h.s[w12, 0], p0/m, z16.s + mov za1h.s[w12, 0], p0/m, z17.s + add w10, w12, #16 + ldr w10, [x4, w10, uxtw #2] // A[k, w12+16] + dup z4.s, w10 + mov z18.s, p0/m, za2h.s[w12, 0] + mov z19.s, p0/m, za3h.s[w12, 0] + mla z18.s, p0/m, z2.s, z4.s // C[w12+16, 0..16] += A[w12+16] * B[0..16] + mla z19.s, p0/m, z3.s, z4.s // C[w12+16, 16..32] += A[w12+16] * B[16..32] + mov za2h.s[w12, 0], p0/m, z18.s + mov za3h.s[w12, 0], p0/m, z19.s + add w12, w12, #1 + cmp w12, #16 + b.lt .Lkt + add x4, x4, #128 + add x2, x2, #128 + subs x3, x3, #1 + b.ne .Lk32 + b .non_linear_loop + +.Lmatmul_loop: + ld1b {z0.b, z1.b}, pn8/z, [x4] + ld1b {z2.b, z3.b}, pn8/z, [x2] + add x4, x4, #128 + add x2, x2, #128 + smopa za0.s, p0/m, p0/m, z0.b, z2.b + smopa za1.s, p0/m, p0/m, z0.b, z3.b + smopa za2.s, p0/m, p0/m, z1.b, z2.b + smopa za3.s, p0/m, p0/m, 
z1.b, z3.b + subs x3, x3, #4 + b.gt .Lmatmul_loop + b .non_linear_loop + +.clear: + zero {za} // zero the whole ZA array (all four .s accumulator tiles) + b .non_linear_loop + +// -------- Store: i32 tile -> memory (port of Phase 1 f32 store) ----------- + +.store: + ldp x5, x6, [x0, #8] // ptr, row_byte_stride + ldp x7, x9, [x0, #24] // col_byte_stride, item_size + + cmp x7, #4 // fast path requires columns exactly 4 bytes apart... + b.ne .Lstore_generic + cmp x9, #4 // ...and 4-byte items; everything else goes through the scratch + b.ne .Lstore_generic + + add x4, x5, #64 // x4 = destination of the za1/za3 half of each row (x5 + 64 bytes) + mov w12, #0 +.Lstore_top: + st1w {za0h.s[w12, 0]}, p0, [x5] + st1w {za1h.s[w12, 0]}, p0, [x4] + add x5, x5, x6 // next destination row (row_byte_stride) + add x4, x4, x6 + add w12, w12, #1 + cmp w12, #16 // 16 horizontal slices are written per tile + b.lt .Lstore_top + mov w12, #0 // restart the slice index for za2/za3; x5/x4 keep advancing +.Lstore_bot: + st1w {za2h.s[w12, 0]}, p0, [x5] + st1w {za3h.s[w12, 0]}, p0, [x4] + add x5, x5, x6 + add x4, x4, x6 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lstore_bot + b .non_linear_loop + +.Lstore_generic: + mov x13, x9 // preserve item_size before x9 is reused as a ptr + mov x4, x1 // spill ZA to the scratch at x1: 128 bytes per row + add x9, x1, #64 // x9 = second half of each scratch row + mov w12, #0 +.Lstore_spill_top: + st1w {za0h.s[w12, 0]}, p0, [x4] + st1w {za1h.s[w12, 0]}, p0, [x9] + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lstore_spill_top + mov w12, #0 +.Lstore_spill_bot: + st1w {za2h.s[w12, 0]}, p0, [x4] + st1w {za3h.s[w12, 0]}, p0, [x9] + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lstore_spill_bot + + mov x3, #0 // x3 = row counter for the scalar copy-out +.Lstore_row: + mov x4, x5 // x4 = destination cursor within the current row + mov x10, #0 // x10 = column counter + lsl x9, x3, #7 // byte offset of scratch row x3 (x3 * 128) + add x11, x1, x9 // x11 = scratch cursor for this row +.Lstore_col: + ldr w9, [x11], #4 // next spilled i32, post-incrementing the scratch cursor + cmp x13, #1 // item_size: 1 -> strb, 2 -> strh, else (4) -> str + b.eq .Lstore_b1 + cmp x13, #2 + b.eq .Lstore_b2 + str w9, [x4] + b .Lstore_cnext +.Lstore_b1: + strb w9, [x4] // stores the low 8 bits only (no saturation) + b .Lstore_cnext +.Lstore_b2: + strh w9, [x4] // stores the low 16 bits only (no saturation) +.Lstore_cnext: + add x4, x4, x7 // advance by col_byte_stride + add x10, x10, #1 + cmp x10, #32 + b.lt .Lstore_col + add x5, x5, x6 // advance by row_byte_stride + add x3, x3, #1 + cmp x3, #32 + b.lt .Lstore_row + b .non_linear_loop + +// -------- LoadTile: ZA := row-major i32 tile from memory ------------------- + +.load_tile: + ldr x2, [x0, #16] + add x4, x2, #64 + mov w12, #0 +.Lloadtile_top: + ld1w
{z6.s}, p0/z, [x2] + ld1w {z7.s}, p0/z, [x4] + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add x2, x2, #128 // source rows are read 128 bytes apart + add x4, x4, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lloadtile_top + mov w12, #0 // restart the slice index for za2/za3 +.Lloadtile_bot: + ld1w {z6.s}, p0/z, [x2] + ld1w {z7.s}, p0/z, [x4] + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add x2, x2, #128 + add x4, x4, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lloadtile_bot + b .non_linear_loop + +// -------- AddUnicast: ZA += C (strided load + add) ------------------------ + +.add_unicast: + ldp x5, x6, [x0, #8] // ptr, row_byte_stride + ldp x7, x9, [x0, #24] // col_byte_stride, item_size + + cmp x7, #4 // same fast-path test as .store: 4-byte column stride... + b.ne .Laddu_generic + cmp x9, #4 // ...and 4-byte items + b.ne .Laddu_generic + + add x4, x5, #64 + mov w12, #0 +.Laddu_top: + ld1w {z8.s}, p0/z, [x5] + ld1w {z9.s}, p0/z, [x4] + mov z6.s, p0/m, za0h.s[w12, 0] // read the ZA slice into a Z register + mov z7.s, p0/m, za1h.s[w12, 0] + add z6.s, p0/m, z6.s, z8.s // accumulate the C values + add z7.s, p0/m, z7.s, z9.s + mov za0h.s[w12, 0], p0/m, z6.s // write the sum back to ZA + mov za1h.s[w12, 0], p0/m, z7.s + add x5, x5, x6 + add x4, x4, x6 + add w12, w12, #1 + cmp w12, #16 + b.lt .Laddu_top + mov w12, #0 +.Laddu_bot: + ld1w {z8.s}, p0/z, [x5] + ld1w {z9.s}, p0/z, [x4] + mov z6.s, p0/m, za2h.s[w12, 0] + mov z7.s, p0/m, za3h.s[w12, 0] + add z6.s, p0/m, z6.s, z8.s + add z7.s, p0/m, z7.s, z9.s + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add x5, x5, x6 + add x4, x4, x6 + add w12, w12, #1 + cmp w12, #16 + b.lt .Laddu_bot + b .non_linear_loop + +.Laddu_generic: + // Strided gather to scratch, then contig accumulate (mirrors Phase 1).
+ mov x3, #0 + mov x10, x1 // x10 = dense write cursor into the scratch +.Laddu_gather_row: + mov x11, x5 + mov x4, #0 +.Laddu_gather_col: + ldr w9, [x11] // NOTE(review): always a 32-bit load; item_size (x9 above) is overwritten here and never consulted, unlike .store - confirm C is always i32 on this path + str w9, [x10], #4 + add x11, x11, x7 // advance by col_byte_stride + add x4, x4, #1 + cmp x4, #32 + b.lt .Laddu_gather_col + add x5, x5, x6 // advance by row_byte_stride + add x3, x3, #1 + cmp x3, #32 + b.lt .Laddu_gather_row + + mov x4, x1 // re-read the gathered tile from scratch, 128 bytes per row + add x9, x1, #64 + mov w12, #0 +.Laddu_apply_top: + ld1w {z8.s}, p0/z, [x4] + ld1w {z10.s}, p0/z, [x9] + mov z6.s, p0/m, za0h.s[w12, 0] + mov z7.s, p0/m, za1h.s[w12, 0] + add z6.s, p0/m, z6.s, z8.s + add z7.s, p0/m, z7.s, z10.s + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Laddu_apply_top + mov w12, #0 +.Laddu_apply_bot: + ld1w {z8.s}, p0/z, [x4] + ld1w {z10.s}, p0/z, [x9] + mov z6.s, p0/m, za2h.s[w12, 0] + mov z7.s, p0/m, za3h.s[w12, 0] + add z6.s, p0/m, z6.s, z8.s + add z7.s, p0/m, z7.s, z10.s + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Laddu_apply_bot + b .non_linear_loop + +// -------- AddRowColProducts: ZA += rows ⊗ cols (i32 outer product) -------- +// +// rows: 32 i32 (broadcast per M-row), cols: 32 i32 (lane vector per N-col). +// Per ZA row, we need: ZA[i, j] += rows[i] * cols[j]. Slice-by-slice.
+ +.add_row_col_products: + ldp x2, x3, [x0, #8] // rows ptr, cols ptr + ld1w {z4.s}, p0/z, [x3] // cols[0..16] + ld1w {z5.s}, p0/z, [x3, #1, mul vl] // cols[16..32] + + // Top 16 rows + mov w12, #0 +.Larcp_top: + ldr w9, [x2], #4 // rows[i], post-incrementing the rows pointer + dup z16.s, w9 // broadcast rows[i] to z16 + mov z6.s, p0/m, za0h.s[w12, 0] + mov z7.s, p0/m, za1h.s[w12, 0] + mla z6.s, p0/m, z16.s, z4.s // z6 += z16 * cols[0..16] + mla z7.s, p0/m, z16.s, z5.s // z7 += z16 * cols[16..32] + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .Larcp_top + // Bottom 16 rows + mov w12, #0 +.Larcp_bot: + ldr w9, [x2], #4 // x2 simply continues into rows[16..32] + dup z16.s, w9 + mov z6.s, p0/m, za2h.s[w12, 0] + mov z7.s, p0/m, za3h.s[w12, 0] + mla z6.s, p0/m, z16.s, z4.s + mla z7.s, p0/m, z16.s, z5.s + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .Larcp_bot + b .non_linear_loop + +// -------- scalar fuse ops: broadcast scalar, apply lane-wise -------------- +// +// Sub vs SubF (matches Phase 1's f32 convention): +// ScalarSub → result = scalar - z (mnemonic: subr) +// ScalarSubF → result = z - scalar (mnemonic: sub) + +{% macro scalar_op_i32(label, op) %} +{{label}}: + ldr w2, [x0, #8] // 32-bit scalar operand + dup z4.s, w2 // broadcast it to every lane of z4 + mov w12, #0 +.L{{label|replace('.', '')}}_top: + mov z6.s, p0/m, za0h.s[w12, 0] + mov z7.s, p0/m, za1h.s[w12, 0] + {{op}} z6.s, p0/m, z6.s, z4.s + {{op}} z7.s, p0/m, z7.s, z4.s + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .L{{label|replace('.', '')}}_top + mov w12, #0 +.L{{label|replace('.', '')}}_bot: + mov z6.s, p0/m, za2h.s[w12, 0] + mov z7.s, p0/m, za3h.s[w12, 0] + {{op}} z6.s, p0/m, z6.s, z4.s + {{op}} z7.s, p0/m, z7.s, z4.s + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .L{{label|replace('.', '')}}_bot + b .non_linear_loop +{% endmacro %} + +{{ scalar_op_i32('.scalar_add', 'add') }} +{{
scalar_op_i32('.scalar_mul', 'mul') }} +{{ scalar_op_i32('.scalar_sub', 'subr') }} +{{ scalar_op_i32('.scalar_sub_flipped', 'sub') }} +{{ scalar_op_i32('.scalar_min', 'smin') }} +{{ scalar_op_i32('.scalar_max', 'smax') }} + +// -------- per_col fuse ops: 32-elem vector, broadcast across M rows ------ + +{% macro per_col_op_i32(label, op) %} +{{label}}: + ldr x2, [x0, #8] // pointer to the per-column vector + ld1w {z4.s}, p0/z, [x2] // first vector-length chunk, paired with za0/za2 + ld1w {z5.s}, p0/z, [x2, #1, mul vl] // second chunk, paired with za1/za3 + mov w12, #0 +.L{{label|replace('.', '')}}_top: + mov z6.s, p0/m, za0h.s[w12, 0] + mov z7.s, p0/m, za1h.s[w12, 0] + {{op}} z6.s, p0/m, z6.s, z4.s + {{op}} z7.s, p0/m, z7.s, z5.s + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .L{{label|replace('.', '')}}_top + mov w12, #0 +.L{{label|replace('.', '')}}_bot: + mov z6.s, p0/m, za2h.s[w12, 0] + mov z7.s, p0/m, za3h.s[w12, 0] + {{op}} z6.s, p0/m, z6.s, z4.s + {{op}} z7.s, p0/m, z7.s, z5.s + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .L{{label|replace('.', '')}}_bot + b .non_linear_loop +{% endmacro %} + +{{ per_col_op_i32('.per_col_add', 'add') }} +{{ per_col_op_i32('.per_col_mul', 'mul') }} +{{ per_col_op_i32('.per_col_sub', 'subr') }} +{{ per_col_op_i32('.per_col_sub_flipped', 'sub') }} +{{ per_col_op_i32('.per_col_min', 'smin') }} +{{ per_col_op_i32('.per_col_max', 'smax') }} + +// -------- per_row fuse ops: 32-elem vector, one scalar per M row --------- + +{% macro per_row_op_i32(label, op) %} +{{label}}: + ldr x2, [x0, #8] // pointer to the per-row vector + add x3, x2, #64 // x3 = entries used by the za2/za3 loop (64 bytes = 16 i32 further on) + mov w12, #0 +.L{{label|replace('.', '')}}_top: + ldr w4, [x2], #4 // this row's scalar + dup z4.s, w4 // broadcast it across the row + mov z6.s, p0/m, za0h.s[w12, 0] + mov z7.s, p0/m, za1h.s[w12, 0] + {{op}} z6.s, p0/m, z6.s, z4.s + {{op}} z7.s, p0/m, z7.s, z4.s + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .L{{label|replace('.', '')}}_top + mov w12, #0 +.L{{label|replace('.', '')}}_bot: + ldr w4, [x3], #4 + dup z4.s,
w4 + mov z6.s, p0/m, za2h.s[w12, 0] + mov z7.s, p0/m, za3h.s[w12, 0] + {{op}} z6.s, p0/m, z6.s, z4.s + {{op}} z7.s, p0/m, z7.s, z4.s + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .L{{label|replace('.', '')}}_bot + b .non_linear_loop +{% endmacro %} + +{{ per_row_op_i32('.per_row_add', 'add') }} +{{ per_row_op_i32('.per_row_mul', 'mul') }} +{{ per_row_op_i32('.per_row_sub', 'subr') }} +{{ per_row_op_i32('.per_row_sub_flipped', 'sub') }} +{{ per_row_op_i32('.per_row_min', 'smin') }} +{{ per_row_op_i32('.per_row_max', 'smax') }} + +// -------- Quantization fuse ops (bit-exact port of generic/rounding.rs) ---- +// +// Strategy: spill the 32x32 i32 ZA tile to the 4 KiB scratch (x1), quantize +// element-wise in SCALAR GP registers (streaming-mode legal: smull/lsr/asr/ +// cneg/cset/... are base A64 and unaffected by PSTATE.SM), then reload to ZA. +// Quant is not the hot path; this mirrors the scalar approach already proven +// in arm64/sve/sve_mmm_i32.c. Everything is inlined (no `bl` — a nested call +// would clobber x30 and corrupt the final `ret`). +// +// Bit-exactness: the reference forms the FULL i64 product (mult*v) and does a +// single magnitude-rounding shift by (shift+31) with a per-policy nudge. A +// vector sqdmulh+srshl truncates the low 31 bits before the second shift, so +// it is NOT equivalent — hence the i64 scalar port. +// +// RoundingPolicy: Native=0 Zero=1 Away=2 MinusInf=3 PlusInf=4 Even=5 Odd=6. + +// Spill ZA0..ZA3 -> scratch[x1] as a contiguous 32x32 row-major i32 matrix +// (same layout the generic store path uses). Clobbers x4, x9, w12.
+{% macro za_spill(sfx) %} + mov x4, x1 // x4 = write cursor for the za0/za2 half of each scratch row + add x9, x1, #64 // x9 = write cursor for the za1/za3 half (64 bytes further on) + mov w12, #0 +.Lspt_{{sfx}}: + st1w {za0h.s[w12, 0]}, p0, [x4] + st1w {za1h.s[w12, 0]}, p0, [x9] + add x4, x4, #128 // next scratch row (128 bytes per row) + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 // note: cmp also updates NZCV + b.lt .Lspt_{{sfx}} + mov w12, #0 // restart the slice index for za2/za3; x4/x9 keep advancing +.Lspb_{{sfx}}: + st1w {za2h.s[w12, 0]}, p0, [x4] + st1w {za3h.s[w12, 0]}, p0, [x9] + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lspb_{{sfx}} +{% endmacro %} + +// Reload scratch[x1] (32x32 row-major i32) -> ZA0..ZA3. Clobbers x4,x9,w12,z6,z7. +{% macro za_reload(sfx) %} + mov x4, x1 // x4 = read cursor for the za0/za2 half of each scratch row + add x9, x1, #64 // x9 = read cursor for the za1/za3 half + mov w12, #0 +.Lrlt_{{sfx}}: + ld1w {z6.s}, p0/z, [x4] + ld1w {z7.s}, p0/z, [x9] + mov za0h.s[w12, 0], p0/m, z6.s // move the loaded vectors into the ZA slices + mov za1h.s[w12, 0], p0/m, z7.s + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lrlt_{{sfx}} + mov w12, #0 +.Lrlb_{{sfx}}: + ld1w {z6.s}, p0/z, [x4] + ld1w {z7.s}, p0/z, [x9] + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lrlb_{{sfx}} +{% endmacro %} + +// Magnitude-rounding shared by q_scale and q_shr (mirrors `Mul for Scaler` +// / `i32::q_shr`). In: x12 = val (i64), x5 = shift, x6 = policy. Out: w14 (i32). +// Clobbers x13,x15,x16,x17. Preserves x5,x6,x7,x10,x11,x12.
+{% macro round_mag(sfx) %} + cmp x5, #0 // shift <= 0 takes the plain left-shift path below + b.gt .Lrpos_{{sfx}} + neg x13, x5 + lsl x13, x12, x13 // val << (-shift) + mov w14, w13 // keep the low 32 bits as the result + b .Lrend_{{sfx}} +.Lrpos_{{sfx}}: + cmp x12, #0 + cneg x15, x12, mi // x15 = |val| + sub x13, x5, #1 + mov x16, #1 + lsl x16, x16, x13 // x16 = half = 1 << (shift-1) + cmp x6, #2 // Away -> nudge 0 + b.eq .Lrn0_{{sfx}} + cmp x6, #1 // Zero -> nudge -1 + b.ne .Lrna_{{sfx}} + mov x17, #-1 + b .Lrnd_{{sfx}} +.Lrna_{{sfx}}: + cmp x6, #3 // MinusInf -> -(val >= 0) + b.ne .Lrnb_{{sfx}} + cmp x12, #0 + cset x17, ge + neg x17, x17 + b .Lrnd_{{sfx}} +.Lrnb_{{sfx}}: + cmp x6, #4 // PlusInf -> -(val <= 0) + b.ne .Lrnc_{{sfx}} + cmp x12, #0 + cset x17, le + neg x17, x17 + b .Lrnd_{{sfx}} +.Lrnc_{{sfx}}: + cmp x6, #5 // Even -> ((|val|>>shift)&1) - 1 + b.ne .Lrno_{{sfx}} + lsr x17, x15, x5 + and x17, x17, #1 + sub x17, x17, #1 + b .Lrnd_{{sfx}} +.Lrno_{{sfx}}: // Odd -> -((|val|>>shift)&1); NOTE(review): also reached by every policy value other than 1..5 (including Native=0) - confirm callers never pass those + lsr x17, x15, x5 + and x17, x17, #1 + neg x17, x17 + b .Lrnd_{{sfx}} +.Lrn0_{{sfx}}: + mov x17, #0 +.Lrnd_{{sfx}}: + add x15, x15, x16 // |val| + half + add x15, x15, x17 // ... + nudge (x17 is 0 or -1) + lsr x15, x15, x5 // (|val| + half + nudge) >> shift + cmp x12, #0 + cneg x14, x15, mi // signum(val) * mag +.Lrend_{{sfx}}: +{% endmacro %} + +// QScale(shift, policy, mult): val = mult*v (i64); shift += 31; magnitude round. +.q_scale: + ldr x5, [x0, #8] // shift (isize) + ldr x6, [x0, #16] // policy + ldr w7, [x0, #24] // mult (i32) + add x5, x5, #31 // fold the fixed 31-bit scale into the shift + {{ za_spill('qsc') }} + mov x10, x1 // x10 = element cursor over the spilled tile + mov x11, #1024 // 32 * 32 elements to process +.Lqsc_loop: + ldr w9, [x10] + smull x12, w7, w9 // val = (i64)mult * (i64)v + {{ round_mag('qsc') }} + str w14, [x10], #4 // write the rounded value back in place, then advance + subs x11, x11, #1 + b.ne .Lqsc_loop + {{ za_reload('qsc') }} + b .non_linear_loop + +// RoundingShiftRight(shift, policy): val = v (i64); magnitude round (shift>0).
+.q_shr: + ldr x5, [x0, #8] // shift (usize, >= 1) + ldr x6, [x0, #16] // policy + {{ za_spill('qsr') }} + mov x10, x1 // x10 = element cursor over the spilled tile + mov x11, #1024 // 32 * 32 elements to process +.Lqsr_loop: + ldr w9, [x10] + sxtw x12, w9 // val = (i64)v + {{ round_mag('qsr') }} + str w14, [x10], #4 // write the rounded value back in place, then advance + subs x11, x11, #1 + b.ne .Lqsr_loop + {{ za_reload('qsr') }} + b .non_linear_loop + +// ShiftLeft(shift): result = v << shift (32-bit wrapping, matches i32::q_shl). +.q_shl: + ldr x5, [x0, #8] // shift (usize) + {{ za_spill('qsl') }} + mov x10, x1 + mov x11, #1024 +.Lqsl_loop: + ldr w9, [x10] + lsl w9, w9, w5 // W-register variable shift: the amount is taken modulo 32 + str w9, [x10], #4 + subs x11, x11, #1 + b.ne .Lqsl_loop + {{ za_reload('qsl') }} + b .non_linear_loop + +// -------- LeakyRelu (excluded via CAN_FUSE_I32) --------------------------- + +.leaky_relu: + b .unsupported + +// -------- epilogue -------------------------------------------------------- + +.return: + smstop // leave streaming mode and disable ZA + add sp, sp, #4096 // drop 4096 bytes of stack - presumably the scratch set up in the prologue (not visible here) + ldp q14, q15, [sp, #96] // restore q8..q15 from the 128-byte save area + ldp q12, q13, [sp, #64] + ldp q10, q11, [sp, #32] + ldp q8, q9, [sp], #128 // last pair also pops the save area + ret diff --git a/linalg/benches/activations_avx512.rs b/linalg/benches/activations_avx512.rs new file mode 100644 index 0000000000..f29ca73152 --- /dev/null +++ b/linalg/benches/activations_avx512.rs @@ -0,0 +1,108 @@ +// Microbenchmark: AVX-512 (zmm, 16-wide) element-wise activation kernels vs +// their x86 predecessor. +// +// sigmoid, tanh : predecessor = FMA (256-bit, 8-wide) kernel +// hardswish, leaky_relu, +// silu, gelu : predecessor = generic scalar kernel +// (no FMA kernel exists on x86) +// +// All buffers are 64-byte aligned (AVX-512 alignment_bytes) and a multiple of +// 64 elements so every kernel's nr() divides the length. Criterion reports the +// distribution; the min of the samples is the relevant "min-of-N" number.
+ +use criterion::*; +use tract_data::prelude::*; +use tract_linalg::element_wise::ElementWiseKer; + +const N: usize = 1024; + +fn aligned_input() -> Tensor { + let mut t = unsafe { Tensor::uninitialized_aligned::(&[N], 64).unwrap() }; + let s = unsafe { t.as_slice_mut_unchecked::() }; + for (i, x) in s.iter_mut().enumerate() { + *x = (i as f32 / 10.0).sin() * 5.0; + } + t +} + +// Enable FTZ/DAZ (flush-to-zero, denormals-are-zero) for the whole process so +// repeated in-place application of a kernel to its own output cannot collapse +// into denormal arithmetic (extremely slow on x86) and distort the timing. +// Mirrors what the sigmoid/tanh kernels already do internally via MXCSR. +#[cfg(target_arch = "x86_64")] +fn enable_ftz_daz() { + // Set MXCSR bit 15 (FTZ) and bit 6 (DAZ) directly; the safe intrinsic + // wrappers are deprecated in favour of inline asm. + unsafe { + let mut mxcsr: u32 = 0; + std::arch::asm!("stmxcsr [{p}]", p = in(reg) &mut mxcsr); + mxcsr |= (1 << 15) | (1 << 6); + std::arch::asm!("ldmxcsr [{p}]", p = in(reg) &mxcsr); + } +} + +// In-place throughput, matching the convention of the existing element-wise +// benches (sigmoid.rs / silu.rs). +macro_rules! bench_pair { + ($c:expr, $name:expr, $pred_label:expr, $pred:ty, $avx512:ty $(, $param:expr)?) 
=> {{ + let mut group = $c.benchmark_group($name); + group.throughput(Throughput::Elements(N as u64)); + let mut tp = aligned_input(); + let sp = unsafe { tp.as_slice_mut_unchecked::() }; + group.bench_function($pred_label, |b| { + b.iter(|| <$pred>::run(sp, ($($param)?))) + }); + if std::is_x86_feature_detected!("avx512f") { + let mut ta = aligned_input(); + let sa = unsafe { ta.as_slice_mut_unchecked::() }; + group.bench_function("avx512", |b| { + b.iter(|| <$avx512>::run(sa, ($($param)?))) + }); + } + group.finish(); + }}; +} + +fn benches(c: &mut Criterion) { + #[cfg(target_arch = "x86_64")] + enable_ftz_daz(); + use tract_linalg::x86_64_fma::act::*; + use tract_linalg::x86_64_fma::{ + avx512_sigmoid_f32, avx512_tanh_f32, fma_sigmoid_f32, fma_tanh_f32, + }; + + bench_pair!(c, "sigmoid_f32", "fma", fma_sigmoid_f32, avx512_sigmoid_f32); + bench_pair!(c, "tanh_f32", "fma", fma_tanh_f32, avx512_tanh_f32); + bench_pair!( + c, + "hardswish_f32", + "generic", + tract_linalg::generic::SHardSwish4, + x86_64_avx512_hardswish_f32_64n + ); + bench_pair!( + c, + "leaky_relu_f32", + "generic", + tract_linalg::generic::SLeakyRelu4, + x86_64_avx512_leaky_relu_f32_64n, + 0.1f32 + ); + bench_pair!( + c, + "silu_f32", + "generic", + tract_linalg::generic::SSiLU4, + x86_64_avx512_silu_f32_16n + ); + bench_pair!( + c, + "gelu_f32", + "generic", + tract_linalg::generic::SGelu4, + x86_64_avx512_gelu_f32_16n + ); +} + +criterion_group!(g, benches); +criterion_main!(g); diff --git a/linalg/benches/activations_avx512_f16.rs b/linalg/benches/activations_avx512_f16.rs new file mode 100644 index 0000000000..a73a735477 --- /dev/null +++ b/linalg/benches/activations_avx512_f16.rs @@ -0,0 +1,83 @@ +// Microbenchmark: AVX-512 f16 element-wise activations vs the generic scalar +// f16 kernels (no FMA f16 predecessor exists on x86 — the generic baseline +// already runs `(*v - max).to_f32()`-style per-element conversions). 
Buffers +// are 64-byte aligned (alignment that the AVX-512 path uses internally) and +// a multiple of 64 elements. + +use criterion::*; +use tract_data::prelude::*; +use tract_linalg::element_wise::ElementWiseKer; + +const N: usize = 1024; + +fn aligned_input() -> Tensor { + let mut t = unsafe { Tensor::uninitialized_aligned::(&[N], 64).unwrap() }; + let s = unsafe { t.as_slice_mut_unchecked::() }; + for (i, x) in s.iter_mut().enumerate() { + *x = f16::from_f32((i as f32 / 10.0).sin() * 5.0); + } + t +} + +macro_rules! bench_pair { + ($c:expr, $name:expr, $pred:ty, $avx512:ty $(, $param:expr)?) => {{ + let mut group = $c.benchmark_group($name); + group.throughput(Throughput::Elements(N as u64)); + let mut tp = aligned_input(); + let sp = unsafe { tp.as_slice_mut_unchecked::() }; + group.bench_function("generic", |b| { + b.iter(|| <$pred>::run(sp, ($($param)?))) + }); + if std::is_x86_feature_detected!("avx512f") { + let mut ta = aligned_input(); + let sa = unsafe { ta.as_slice_mut_unchecked::() }; + group.bench_function("avx512", |b| { + b.iter(|| <$avx512>::run(sa, ($($param)?))) + }); + } + group.finish(); + }}; +} + +fn benches(c: &mut Criterion) { + bench_pair!( + c, + "sigmoid_f16", + tract_linalg::generic::sigmoid::HSigmoid8, + tract_linalg::x86_64_fma::act_f16::x86_64_avx512_sigmoid_f16_16n + ); + bench_pair!( + c, + "tanh_f16", + tract_linalg::generic::tanh::HTanh8, + tract_linalg::x86_64_fma::act_f16::x86_64_avx512_tanh_f16_16n + ); + bench_pair!( + c, + "hardswish_f16", + tract_linalg::generic::hardswish::HHardSwish8, + tract_linalg::x86_64_fma::act_f16::x86_64_avx512_hardswish_f16_64n + ); + bench_pair!( + c, + "leaky_relu_f16", + tract_linalg::generic::leaky_relu::HLeakyRelu8, + tract_linalg::x86_64_fma::act_f16::x86_64_avx512_leaky_relu_f16_64n, + f16::from_f32(0.1) + ); + bench_pair!( + c, + "silu_f16", + tract_linalg::generic::silu::HSiLU8, + tract_linalg::x86_64_fma::act_f16::x86_64_avx512_silu_f16_16n + ); + bench_pair!( + c, + "gelu_f16", + 
tract_linalg::generic::gelu::HGelu8, + tract_linalg::x86_64_fma::act_f16::x86_64_avx512_gelu_f16_16n + ); +} + +criterion_group!(g, benches); +criterion_main!(g); diff --git a/linalg/benches/activations_avx512_fp16.rs b/linalg/benches/activations_avx512_fp16.rs new file mode 100644 index 0000000000..c754b79827 --- /dev/null +++ b/linalg/benches/activations_avx512_fp16.rs @@ -0,0 +1,68 @@ +// Microbench: AVX-512_FP16 native f16 element-wise activations vs the +// f32-roundtrip versions in `act_f16.rs` (which were the AVX-512 f16 path +// before native f16 ISA was available). Both run on 64-byte-aligned, 1024- +// element buffers — same workload as the existing activations_avx512_f16 +// bench, just adding the native-fp16 column. + +use criterion::*; +use tract_data::prelude::*; +use tract_linalg::element_wise::ElementWiseKer; + +const N: usize = 1024; + +fn aligned_input() -> Tensor { + let mut t = unsafe { Tensor::uninitialized_aligned::(&[N], 64).unwrap() }; + let s = unsafe { t.as_slice_mut_unchecked::() }; + for (i, x) in s.iter_mut().enumerate() { + *x = f16::from_f32((i as f32 / 10.0).sin() * 5.0); + } + t +} + +macro_rules! bench_triple { + ($c:expr, $name:expr, $pred:ty, $roundtrip:ty, $native:ty $(, $param:expr)?) 
=> {{ + let mut group = $c.benchmark_group($name); + group.throughput(Throughput::Elements(N as u64)); + let mut tg = aligned_input(); + let sg = unsafe { tg.as_slice_mut_unchecked::() }; + group.bench_function("generic", |b| { + b.iter(|| <$pred>::run(sg, ($($param)?))) + }); + if std::is_x86_feature_detected!("avx512f") { + let mut tr = aligned_input(); + let sr = unsafe { tr.as_slice_mut_unchecked::() }; + group.bench_function("avx512_f32roundtrip", |b| { + b.iter(|| <$roundtrip>::run(sr, ($($param)?))) + }); + } + if std::is_x86_feature_detected!("avx512fp16") { + let mut tn = aligned_input(); + let sn = unsafe { tn.as_slice_mut_unchecked::() }; + group.bench_function("avx512fp16_native", |b| { + b.iter(|| <$native>::run(sn, ($($param)?))) + }); + } + group.finish(); + }}; +} + +fn benches(c: &mut Criterion) { + bench_triple!( + c, + "hardswish_f16", + tract_linalg::generic::hardswish::HHardSwish8, + tract_linalg::x86_64_fma::act_f16::x86_64_avx512_hardswish_f16_64n, + tract_linalg::x86_64_fma::act_f16_fp16::x86_64_avx512fp16_hardswish_f16_128n + ); + bench_triple!( + c, + "leaky_relu_f16", + tract_linalg::generic::leaky_relu::HLeakyRelu8, + tract_linalg::x86_64_fma::act_f16::x86_64_avx512_leaky_relu_f16_64n, + tract_linalg::x86_64_fma::act_f16_fp16::x86_64_avx512fp16_leaky_relu_f16_128n, + f16::from_f32(0.1) + ); +} + +criterion_group!(g, benches); +criterion_main!(g); diff --git a/linalg/benches/erf.rs b/linalg/benches/erf.rs new file mode 100644 index 0000000000..0fa4661f93 --- /dev/null +++ b/linalg/benches/erf.rs @@ -0,0 +1,38 @@ +// Microbenchmark: AVX-512 (zmm, 16-wide) erf kernel vs the generic scalar +// SErf4 (no FMA predecessor exists on x86). All buffers are 64-byte aligned +// (AVX-512 alignment_bytes) and a multiple of 64 elements so the kernel's +// nr() = 64 divides the length. 
+ +use criterion::*; +use tract_data::prelude::*; +use tract_linalg::element_wise::ElementWiseKer; + +const N: usize = 1024; + +fn aligned_input() -> Tensor { + let mut t = unsafe { Tensor::uninitialized_aligned::(&[N], 64).unwrap() }; + let s = unsafe { t.as_slice_mut_unchecked::() }; + for (i, x) in s.iter_mut().enumerate() { + *x = (i as f32 / 10.0).sin() * 5.0; + } + t +} + +fn erf_f32(c: &mut Criterion) { + let mut g = c.benchmark_group("erf_f32"); + g.throughput(Throughput::Elements(N as u64)); + let mut tp = aligned_input(); + let sp = unsafe { tp.as_slice_mut_unchecked::() }; + g.bench_function("generic", |b| b.iter(|| tract_linalg::generic::SErf4::run(sp, ()))); + if std::is_x86_feature_detected!("avx512f") { + let mut ta = aligned_input(); + let sa = unsafe { ta.as_slice_mut_unchecked::() }; + g.bench_function("avx512", |b| { + b.iter(|| tract_linalg::x86_64_fma::erf::x86_64_avx512_erf_f32_64n::run(sa, ())) + }); + } + g.finish(); +} + +criterion_group!(g, erf_f32); +criterion_main!(g); diff --git a/linalg/benches/qmmm_i8.rs b/linalg/benches/qmmm_i8.rs new file mode 100644 index 0000000000..38c234b771 --- /dev/null +++ b/linalg/benches/qmmm_i8.rs @@ -0,0 +1,56 @@ +// int8 -> i32 GEMM (qmmm_i32) microbench. A/B the SME SMOPA kernel vs the NEON +// fallback by running twice: default (SME) vs TRACT_SME_DISABLE=1 (arm64simd 8x8). +extern crate criterion; +use criterion::*; +use tract_data::internal::*; +use tract_linalg::mmm::{AsInputValue, FusedSpec}; + +use DatumType::I32; + +fn qmmm(be: &mut criterion::Bencher, &(m, k, n): &(usize, usize, usize)) { + unsafe { + let mmm = tract_linalg::ops().mmm(I32, Some(m), Some(k), Some(n)).unwrap(); + // packing index 1 == i8i8 for both sme_qmmm_i32_32x32 and arm64simd_mmm_i32_8x8. 
+ let a = Tensor::zero::(&[m, k]).unwrap(); + let b = Tensor::zero::(&[k, n]).unwrap(); + let packing = &mmm.packings()[1]; + let pa = packing.0.prepare_one(&a, 1, 0).unwrap(); + let pb = packing.1.prepare_one(&b, 0, 1).unwrap(); + let mut c = Tensor::zero::(&[m, n]).unwrap(); + be.iter(move || { + mmm.run( + m, + n, + &[ + FusedSpec::AddMatMul { + a: AsInputValue::Borrowed(&*pa), + b: AsInputValue::Borrowed(&*pb), + packing: 1, + }, + FusedSpec::Store(mmm.c_view(Some(0), Some(1)).wrap(&c.view_mut())), + ], + ) + }); + } +} + +fn bench(c: &mut Criterion) { + let mut g = c.benchmark_group("qmmm_i8"); + g.sample_size(20); + for &shape in &[ + (256usize, 256usize, 256usize), + (512, 512, 512), + (1024, 1024, 1024), + (128, 768, 768), + (384, 768, 768), + (64, 2048, 2048), + ] { + let (m, k, n) = shape; + g.throughput(Throughput::Elements((m * k * n) as u64)); + g.bench_function(format!("{m}x{k}x{n}"), |be| qmmm(be, &shape)); + } + g.finish(); +} + +criterion::criterion_group!(benches, bench); +criterion::criterion_main!(benches); diff --git a/linalg/benches/rms_norm.rs b/linalg/benches/rms_norm.rs new file mode 100644 index 0000000000..7119ba38b5 --- /dev/null +++ b/linalg/benches/rms_norm.rs @@ -0,0 +1,55 @@ +// Microbench: fused RmsNorm vs the 4-call composition that tract-core currently +// uses (MeanOfSquares + Add + Rsqrt + Mul). The composition is reconstructed +// inline here in the same shape as `core::ops::nn::rms_norm::RmsNorm::eval` +// drives it. Both versions run on a 64-byte-aligned f32 row. 
+ +use criterion::*; +use tract_data::prelude::*; + +fn aligned_row(n: usize) -> Tensor { + let mut t = unsafe { Tensor::uninitialized_aligned::(&[n], 64).unwrap() }; + let s = unsafe { t.as_slice_mut_unchecked::() }; + for (i, x) in s.iter_mut().enumerate() { + *x = (i as f32 / 10.0).sin() * 5.0; + } + t +} + +#[inline(never)] +fn composed_rms_norm(buf: &mut [f32], eps: f32) { + // Same shape as tract-core's RmsNorm::eval: separate passes for sum-of-squares, + // mean, +eps, rsqrt, multiply — each writing/reading the row once. + let mut sum_sq = 0.0_f32; + for &x in buf.iter() { + sum_sq += x * x; + } + let mean_sq = sum_sq / buf.len() as f32; + let added = mean_sq + eps; + let inv_std = added.sqrt().recip(); + for x in buf.iter_mut() { + *x *= inv_std; + } +} + +fn rms_norm(c: &mut Criterion) { + for &n in &[1024usize, 2048, 4096] { + let id = format!("{n}"); + let mut g = c.benchmark_group(format!("rms_norm_f32/{id}")); + g.throughput(Throughput::Elements(n as u64)); + let mut t = aligned_row(n); + let s = unsafe { t.as_slice_mut_unchecked::() }; + g.bench_function("composed", |b| b.iter(|| composed_rms_norm(s, 1e-5))); + g.bench_function("generic", |b| { + b.iter(|| tract_linalg::generic::rms_norm::rms_norm_f32(s, 1e-5)) + }); + if std::is_x86_feature_detected!("avx512f") { + g.bench_function("avx512", |b| { + b.iter(|| tract_linalg::x86_64_fma::rms_norm::rms_norm_f32(s, 1e-5)) + }); + } + g.finish(); + } +} + +criterion_group!(g, rms_norm); +criterion_main!(g); diff --git a/linalg/benches/softmax.rs b/linalg/benches/softmax.rs index 2c9380c928..10a2dd89bf 100644 --- a/linalg/benches/softmax.rs +++ b/linalg/benches/softmax.rs @@ -1,7 +1,7 @@ use criterion::*; use tract_data::prelude::*; use tract_linalg::element_wise::ElementWiseKer; -use tract_linalg::generic::reduce::softmax_l2::SSoftMaxL2; +use tract_linalg::generic::reduce::softmax_l2::{HSoftMaxL2, SSoftMaxL2}; use tract_linalg::reduce::{MapReduceKer, ReduceKer}; #[inline(never)] @@ -42,10 +42,19 @@ fn 
rust_f32(slice: &mut [f32]) { fn softmax_f32(c: &mut Criterion) { let mut group = c.benchmark_group("softmax_f32"); - group.throughput(Throughput::Elements(1500)); - let mut input = unsafe { Tensor::uninitialized_aligned::(&[1500], 16).unwrap() }; + // 1536 = 24*64 = 48*32: a multiple of both the FMA (32) and AVX-512 (64) tile + // widths, 64-byte aligned so both kernels run entirely on their fast aligned + // path (no prefix/suffix scalar fixup) for a fair before/after comparison. + group.throughput(Throughput::Elements(1536)); + let mut input = unsafe { Tensor::uninitialized_aligned::(&[1536], 64).unwrap() }; let mut plain = input.try_as_plain_mut().unwrap(); let input = plain.as_slice_mut::().unwrap(); + // Deterministic finite values so every kernel sees identical, well-behaved + // input (uninitialized memory could contain NaN/huge values that perturb the + // fast-compact-exp int conversion and skew the comparison). + for (i, x) in input.iter_mut().enumerate() { + *x = ((i % 97) as f32) * 0.1 - 5.0; + } group.bench_function("rust", |b| b.iter(|| rust_f32(input))); group.bench_function("loop1/naive", |b| b.iter(|| loop1_f32_naive(input))); group.bench_function("loop1/generic", |b| { @@ -57,6 +66,14 @@ fn softmax_f32(c: &mut Criterion) { tract_linalg::x86_64_fma::max::x86_64_fma_max_f32_32n::red().run(input).unwrap(); }) }); + #[cfg(target_arch = "x86_64")] + if is_x86_feature_detected!("avx512f") { + group.bench_function("loop1/avx512", |b| { + b.iter(|| { + tract_linalg::x86_64_fma::max::x86_64_avx512_max_f32_64n::red().run(input).unwrap(); + }) + }); + } #[cfg(target_arch = "aarch64")] group.bench_function("loop1/intr", |b| { b.iter(|| { @@ -75,6 +92,16 @@ fn softmax_f32(c: &mut Criterion) { .unwrap() }); }); + #[cfg(target_arch = "x86_64")] + if is_x86_feature_detected!("avx512f") { + group.bench_function("loop2/avx512", |b| { + b.iter(|| { + tract_linalg::x86_64_fma::softmax::x86_64_avx512_softmax2_fastcompact_f32_64n::red() + .run_with_params(input, 10.) 
+ .unwrap() + }); + }); + } #[cfg(target_arch = "aarch64")] group.bench_function("loop2/iasm", |b| { b.iter(|| { @@ -107,5 +134,31 @@ fn softmax_f32(c: &mut Criterion) { }); } -criterion_group!(benches, softmax_f32); +fn softmax_f16(c: &mut Criterion) { + let mut group = c.benchmark_group("softmax_f16"); + // 1536 = 64*24 (multiple of avx512 f16 nr=64 and generic h nr=8). + const N: usize = 1536; + group.throughput(Throughput::Elements(N as u64)); + let mut input = unsafe { Tensor::uninitialized_aligned::(&[N], 64).unwrap() }; + let mut plain = input.try_as_plain_mut().unwrap(); + let input = plain.as_slice_mut::().unwrap(); + for (i, x) in input.iter_mut().enumerate() { + *x = f16::from_f32((i as f32 / 10.0).sin() * 5.0); + } + group.bench_function("loop2/generic", |b| { + b.iter(|| HSoftMaxL2::red().run_with_params(input, f16::from_f32(10.0))) + }); + #[cfg(target_arch = "x86_64")] + if std::is_x86_feature_detected!("avx512f") { + group.bench_function("loop2/avx512", |b| { + b.iter(|| { + tract_linalg::x86_64_fma::softmax::x86_64_avx512_softmax2_fastcompact_f16_64n::red() + .run_with_params(input, f16::from_f32(10.0)) + .unwrap() + }); + }); + } +} + +criterion_group!(benches, softmax_f32, softmax_f16); criterion_main!(benches); diff --git a/linalg/benches/vnni_i32.rs b/linalg/benches/vnni_i32.rs new file mode 100644 index 0000000000..59e6f01676 --- /dev/null +++ b/linalg/benches/vnni_i32.rs @@ -0,0 +1,63 @@ +#![allow(dead_code)] +// Kernel-level benchmark: AVX-512 VNNI int8 GEMM (avx512vnni_mmm_i32_8x8, VPDPBUSD +// over the K=4-inner PackedI8K4 layout) vs the AVX2 int8 path (avx2_mmm_i32_8x8, +// vpmaddubsw-style widening). Both run the i8i8 packing (index 1) over the same +// M/K/N so the only difference is the matmul inner loop. 
+use criterion::*; +use tract_data::internal::*; +use tract_linalg::mmm::{AsInputValue, FusedSpec, MatMatMul}; + +fn run_kernel(be: &mut Bencher, mmm: &dyn MatMatMul, m: usize, k: usize, n: usize) { + let a = Tensor::zero_dt(DatumType::I8, &[m, k]).unwrap(); + let b = Tensor::zero_dt(DatumType::I8, &[k, n]).unwrap(); + let (pack_a, pack_b) = &mmm.packings()[1]; + let pa = pack_a.prepare_one(&a, 1, 0).unwrap(); + let pb = pack_b.prepare_one(&b, 0, 1).unwrap(); + let mut scratch = unsafe { mmm.allocate_scratch_space() }; + be.iter_custom(|iters| { + let mut dur = std::time::Duration::default(); + for _ in 0..iters { + let t = std::time::Instant::now(); + unsafe { + mmm.run_with_scratch_space( + m, + n, + scratch.as_mut(), + &[FusedSpec::AddMatMul { + a: AsInputValue::Borrowed(&*pa), + b: AsInputValue::Borrowed(&*pb), + packing: 1, + }], + ) + .unwrap() + }; + dur += t.elapsed(); + } + dur + }); +} + +fn benches(c: &mut Criterion) { + if !std::is_x86_feature_detected!("avx512vnni") { + eprintln!("avx512vnni not available, skipping"); + return; + } + use tract_linalg::x86_64_fma::mmm::*; + for &(m, k, n) in + &[(64usize, 256usize, 64usize), (256, 256, 256), (512, 512, 512), (1024, 1024, 64)] + { + let id = format!("{m}x{k}x{n}"); + let mut g = c.benchmark_group("vnni_i32/packed_packed"); + g.throughput(Throughput::Elements((m * k * n) as u64)); + g.bench_with_input(BenchmarkId::new("avx2", &id), &(m, k, n), |b, &(m, k, n)| { + run_kernel(b, &*avx2_mmm_i32_8x8.mmm(), m, k, n) + }); + g.bench_with_input(BenchmarkId::new("avx512vnni", &id), &(m, k, n), |b, &(m, k, n)| { + run_kernel(b, &*avx512vnni_mmm_i32_8x8.mmm(), m, k, n) + }); + g.finish(); + } +} + +criterion_group!(g, benches); +criterion_main!(g); diff --git a/linalg/build.rs b/linalg/build.rs index 9e2325815b..600a2115f6 100644 --- a/linalg/build.rs +++ b/linalg/build.rs @@ -37,6 +37,37 @@ fn assembler_supports_sme() -> bool { .is_ok() } +// Probe whether the target assembler can encode FEAT_DotProd `sdot` (the 
+// indexed int8 form used by arm64simd_mmm_i32_8x8_dot). Old binutils — notably +// the Debian stretch aarch64 cross-toolchain in CI — predate FEAT_DotProd and +// reject `.cpu ...+dotprod` / `sdot` outright. When the probe fails we skip the +// SDOT kernel and the `tract_arm64_dotprod` cfg; the runtime falls back to the +// SMLAL 8x8 i32 kernel. +fn assembler_supports_dotprod() -> bool { + cc::Build::new() + .file("arm64/arm64simd/dummy_dotprod.S") + .cargo_metadata(false) + .cargo_warnings(false) + .warnings(false) + .try_compile("tract_dotprod_probe") + .is_ok() +} + +// Probe whether the target assembler can encode `vpdpbusd ymm` (AVX-512 VNNI +// with AVX-512 VL, i.e. the 256-bit form). binutils gained this in ~2.30 +// (2018); the Debian stretch toolchain ships 2.28 and rejects the mnemonic. +// When the probe fails we skip the VNNI kernel and the `tract_avx512vnni` cfg; +// the runtime falls back to the AVX2 i32 path. +fn assembler_supports_avx512vnni() -> bool { + cc::Build::new() + .file("x86_64/avx512vnni/dummy_vnni.S") + .cargo_metadata(false) + .cargo_warnings(false) + .warnings(false) + .try_compile("tract_avx512vnni_probe") + .is_ok() +} + fn include_sve() -> bool { // SVE/SVE2 lives on ARMv9 server/mobile cores (Neoverse V1+/N2+, Cortex-X2+, // Graviton 3/4) — Linux aarch64. No Apple silicon has SVE. @@ -130,10 +161,19 @@ fn main() { println!("cargo:rustc-check-cfg=cfg(tract_sme)"); // Set below only when include_sve() and the SVE compiler probe both pass. println!("cargo:rustc-check-cfg=cfg(tract_sve)"); + // Set below only when the aarch64 assembler probe for `sdot` passes. + println!("cargo:rustc-check-cfg=cfg(tract_arm64_dotprod)"); + // Set below only when the x86_64 assembler probe for vpdpbusd ymm passes. 
+ println!("cargo:rustc-check-cfg=cfg(tract_avx512vnni)"); match arch.as_ref() { "x86_64" => { let mut files = preprocess_files("x86_64/fma", &[], &suffix, false); + // The VNNI kernel is compiled separately (conditional on a probe) to + // avoid breaking old assemblers. Remove it from the main file list. + files.retain(|f| { + !f.file_name().and_then(|n| n.to_str()).map_or(false, |n| n.contains("avx512vnni")) + }); files.extend(preprocess_files("x86_64/avx512", &[], &suffix, false)); if os == "windows" { @@ -184,6 +224,20 @@ fn main() { } else { cc::Build::new().files(files).flag("-mfma").compile("x86_64_fma"); } + // VNNI kernel compiled separately so old assemblers (binutils < 2.30, + // e.g. Debian stretch) that can't encode `vpdpbusd ymm` don't break + // the whole x86_64 build. The `tract_avx512vnni` cfg gates the + // matching Rust extern declarations and dispatch registration. + // + // The template stays in x86_64/fma/ (alongside dispatcher.j2 and the + // other partials it includes) so the jinja env can resolve its includes. + if assembler_supports_avx512vnni() { + let tmpl = path::Path::new("x86_64/fma/avx512vnni_mmm_i32_8x8.S.j2"); + let out = out_dir.join(format!("avx512vnni_mmm_i32_8x8_{suffix}.S")); + preprocess_file(tmpl, &out, &[], &suffix, false); + cc::Build::new().file(&out).flag("-mfma").compile("x86_64_avx512vnni"); + println!("cargo:rustc-cfg=tract_avx512vnni"); + } } "arm" | "armv7" => { let files = preprocess_files("arm32/armvfpv2", &[], &suffix, false); @@ -197,13 +251,30 @@ fn main() { cc::Build::new().files(files).flag("-marm").flag("-mfpu=neon").compile("armv7neon"); } "aarch64" => { - let files = preprocess_files( + let mut files = preprocess_files( "arm64/arm64simd", &[("core", vec!["a53", "a55", "gen"])], &suffix, false, ); + // The SDOT kernel is compiled separately (conditional on a probe) so + // old assemblers (binutils < 2.30, e.g. Debian stretch) that can't + // encode `sdot` don't break the whole arm64simd build. 
/// Linux runtime probe for FEAT_DotProd (SDOT/UDOT).
///
/// Returns `false` whenever the `TRACT_DOTPROD_DISABLE` environment variable is
/// set (any value, including empty); otherwise reports bit 20 of the
/// `AT_HWCAP` auxiliary-vector word, which is `HWCAP_ASIMDDP` on aarch64.
#[cfg(target_os = "linux")]
pub fn has_dotprod() -> bool {
    unsafe extern "C" {
        fn getauxval(t: u64) -> u64;
    }
    // Auxiliary-vector key for the hardware-capability word.
    const AUXV_KEY_HWCAP: u64 = 16;
    // HWCAP_ASIMDDP = 1 << 20 on aarch64.
    const DOTPROD_BIT: u64 = 1 << 20;

    let forced_off = std::env::var_os("TRACT_DOTPROD_DISABLE").is_some();
    if forced_off {
        // The override wins before the auxiliary vector is consulted.
        return false;
    }
    let hwcap = unsafe { getauxval(AUXV_KEY_HWCAP) };
    hwcap & DOTPROD_BIT != 0
}
// Fusability predicate for the i32 SME kernel: it is the `can_fuse(...)`
// argument of the `sme_qmmm_i32_32x32` declaration and answers `true` for every
// `FusedSpec` variant except `LeakyRelu`.
//
// The SMOPA i32 kernel implements the quant fuse ops (QScale / RoundingShiftRight
// / ShiftLeft) bit-exactly; only LeakyRelu is unsupported (kernel returns 1).
const CAN_FUSE_I32: fn(&FusedSpec) -> bool = |f| !matches!(f, FusedSpec::LeakyRelu(_));
/// Declares an element-wise erf kernel and its frame tests.
///
/// Expands to `ew_impl!($ti, $func, $nr, $alignment_items)` and, under
/// `cfg(test)`, a `test_<func>` module that instantiates `erf_frame_tests!`
/// for the kernel. `$cond` is forwarded to the tests, which only run their
/// bodies when it evaluates to true.
#[allow(unused_macros)]
macro_rules! erf_impl {
    ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $cond: expr) => {
        ew_impl!($ti, $func, $nr, $alignment_items);
        #[cfg(test)]
        paste! {
            mod [<test_ $func>] {
                use super::*;
                erf_frame_tests!($cond, $ti, $func);
            }
        }
    };
}
/// Runs kernel `K` over `values` (cast to `T`) and checks it against a scalar
/// erf reference through `test_element_wise`.
///
/// The reference evaluates `sign(x) * (1 - 1 / (1 + p(|x|))^16)` in f32, where
/// `p` is the degree-6 polynomial with coefficients `A1..A6` and no constant
/// term, then converts the result back to `T`.
pub fn test_erf<K: ElementWiseKer<T>, T: LADatum + Float>(values: &[f32]) -> TestCaseResult
where
    f32: AsPrimitive<T>,
    T: AsPrimitive<f32>,
{
    let data = tract_data::prelude::tensor1(values);
    let data = data.cast_to::<T>().unwrap();
    let data = data.try_as_plain().unwrap().as_slice::<T>().unwrap();
    crate::frame::element_wise::test::test_element_wise::<K, _, _>(data, |x: T| {
        // Abramowitz & Stegun 7.1.28 six-coefficient approximation (7.1.26 is the
        // five-coefficient exp(-x^2) form), mirroring generic/erf.rs::serf so the
        // test reference matches the production scalar path.
        const A1: f32 = 0.0705230784;
        const A2: f32 = 0.0422820123;
        const A3: f32 = 0.0092705272;
        const A4: f32 = 0.0001520143;
        const A5: f32 = 0.0002765672;
        const A6: f32 = 0.0000430638;
        let x: f32 = x.as_();
        // The formula is evaluated on |x|; the sign is re-applied at the end.
        let signum = x.signum();
        let abs = x.abs();
        // Horner evaluation of A1*|x| + A2*|x|^2 + ... + A6*|x|^6.
        let y = A6 * abs;
        let y = (A5 + y) * abs;
        let y = (A4 + y) * abs;
        let y = (A3 + y) * abs;
        let y = (A2 + y) * abs;
        let y = (A1 + y) * abs;
        let y = 1.0 - (y + 1.0).powi(16).recip();
        y.copysign(signum).as_()
    })
}
/// Parses a cache-size string such as `"512"`, `"32K"` or `"2M"` into bytes.
///
/// Surrounding whitespace is ignored, as is whitespace between the digits and
/// the suffix. A trailing `K`/`k` multiplies by 1024 and `M`/`m` by 1024²; with
/// no suffix the number is taken as bytes.
///
/// Returns 0 ("unknown") when the digits do not parse as `usize` or when the
/// scaled value would overflow `usize`, so a malformed or absurd entry can
/// never panic or wrap into a bogus size.
#[cfg(target_os = "linux")]
fn parse_cache_size(s: &str) -> usize {
    let s = s.trim();
    let (digits, unit): (&str, usize) = if let Some(n) = s.strip_suffix(['K', 'k']) {
        (n, 1024)
    } else if let Some(n) = s.strip_suffix(['M', 'm']) {
        (n, 1024 * 1024)
    } else {
        (s, 1)
    };
    digits
        .trim()
        .parse::<usize>()
        .ok()
        // checked_mul: an overflowing product is treated as unknown, not wrapped.
        .and_then(|count| count.checked_mul(unit))
        .unwrap_or(0)
}
+ for idx in [2usize, 3] { + if let Ok(s) = std::fs::read_to_string(format!( + "/sys/devices/system/cpu/cpu0/cache/index{idx}/size" + )) { + let b = parse_cache_size(s.trim()); + if b > 0 { + return b; + } + } + } + 0 + } + #[cfg(not(any(target_os = "macos", target_os = "linux")))] + { + 0 + } + }) +} + +/// Working-set budget (bytes) for the single-thread cache-block: ~a third of L2 +/// (leaving room for the C accumulator tile + packing metadata). Conservative +/// 256 KiB fallback when L2 is unknown (WASM/Windows/BSD) ⇒ small blocks ≈ the +/// naive loop, so it can never over-block a cache it can't see. +fn block_budget_bytes() -> usize { + let l2 = detect_l2_bytes(); + if l2 == 0 { 256 * 1024 } else { (l2 / 3).clamp(64 * 1024, 8 * 1024 * 1024) } +} + +/// Cache-adaptive panel-block edge: large enough to amortise streaming, small +/// enough that the block's A+B sub-panels (`~blk·(mr+nr)·k·elem_bytes`) stay +/// L2-resident at the given `k`. Capped at [`ST_BLK_MAX`]; the floor of 1 +/// degrades exactly to the naive loop, so an unknown/small cache can never +/// over-block (regression-safe). The budget is **cache-size derived** (not a +/// hard-coded constant), so it is correct across hardware. +#[inline] +fn st_block_edge(mr: usize, nr: usize, k: usize, elem_bytes: usize) -> usize { + if k == 0 { + return ST_BLK_MAX; + } + let per_blk = ((mr + nr) * k * elem_bytes.max(1)).max(1); + (block_budget_bytes() / per_blk).clamp(1, ST_BLK_MAX) +} + +/// Single-thread tile walk over the `m_panels × n_panels` grid, blocked into +/// cache-sized panel blocks for locality (the naive nested loop re-streams the +/// whole inner operand per outer panel at large k; the multithread path already +/// blocks this way via `chunk_grid`). `col_outer` selects the within-block inner +/// order (B-reuse vs A-reuse). Reordering independent tiles changes no result — +/// bit-exact with the naive loop. 
+#[inline] +unsafe fn run_single_thread_blocked( + ker: &K, + m_panels: usize, + n_panels: usize, + k: usize, + col_outer: bool, + scratch: &mut ScratchSpaceImpl, + non_linear: &[FusedSpec], +) -> TractResult<()> { + unsafe { + let blk = st_block_edge(ker.mr(), ker.nr(), k, K::Acc::datum_type().size_of()); + scratch.run_in_tls_scope(|scratch, tls| { + let mut jb = 0; + while jb < n_panels { + let jb_end = (jb + blk).min(n_panels); + let mut ja = 0; + while ja < m_panels { + let ja_end = (ja + blk).min(m_panels); + if col_outer { + for ib in jb..jb_end { + for ia in ja..ja_end { + scratch.run_one_tile(ker, non_linear, tls, ia, ib)?; + } + } + } else { + for ia in ja..ja_end { + for ib in jb..jb_end { + scratch.run_one_tile(ker, non_linear, tls, ia, ib)?; + } + } + } + ja = ja_end; + } + jb = jb_end; + } + TractResult::Ok(()) + }) + } +} + unsafe fn run_with_scratch_space_col_outer( ker: &K, m: usize, n: usize, + k: usize, scratch: &mut ScratchSpaceImpl, non_linear: &[FusedSpec], ) -> TractResult<()> { unsafe { match crate::multithread::current_tract_executor() { - Executor::SingleThread => scratch.run_in_tls_scope(|scratch, tls| { - for ib in 0..n.divceil(ker.nr()) { - for ia in 0..m.divceil(ker.mr()) { - scratch.run_one_tile(ker, non_linear, tls, ia, ib)?; - } - } - TractResult::Ok(()) - }), + Executor::SingleThread => run_single_thread_blocked( + ker, + m.divceil(ker.mr()), + n.divceil(ker.nr()), + k, + true, + scratch, + non_linear, + ), #[cfg(feature = "multithread-mm")] Executor::MultiThread(pool) => chunked_dispatch_rayon( Some(&pool), @@ -327,19 +466,21 @@ unsafe fn run_with_scratch_space_row_outer( ker: &K, m: usize, n: usize, + k: usize, scratch: &mut ScratchSpaceImpl, non_linear: &[FusedSpec], ) -> TractResult<()> { unsafe { match crate::multithread::current_tract_executor() { - Executor::SingleThread => scratch.run_in_tls_scope(|scratch, tls| { - for ia in 0..m.divceil(ker.mr()) { - for ib in 0..n.divceil(ker.nr()) { - scratch.run_one_tile(ker, 
non_linear, tls, ia, ib)?; - } - } - TractResult::Ok(()) - }), + Executor::SingleThread => run_single_thread_blocked( + ker, + m.divceil(ker.mr()), + n.divceil(ker.nr()), + k, + false, + scratch, + non_linear, + ), #[cfg(feature = "multithread-mm")] Executor::MultiThread(pool) => chunked_dispatch_rayon( Some(&pool), diff --git a/linalg/src/frame/mmm/tests/packed_packed.rs b/linalg/src/frame/mmm/tests/packed_packed.rs index 8437990c72..ff27833aa6 100644 --- a/linalg/src/frame/mmm/tests/packed_packed.rs +++ b/linalg/src/frame/mmm/tests/packed_packed.rs @@ -379,3 +379,43 @@ impl PackedPackedProblem { result } } + +// Large-shape frame tests that exercise the single-thread 2D-blocked tile walk +// (`run_single_thread_blocked`): the existing `arbitrary_problem` frame proptests +// only reach 3 panels per dim (m,n < 3·mr), below the ST_BLK=16 blocking +// threshold, so the blocked path was otherwise uncovered. generic_f32_4x4 has +// mr=nr=4, so m,n=80 → 20×20 panels → multiple blocks. Compares the frame +// output against the naive reference (must be bit/approx-exact). 
// Frame-level checks on shapes large enough to span several panel blocks of
// the single-thread tile walk. Each case runs `generic_f32_4x4` (mr = nr = 4
// per the shapes quoted below) through `PackedPackedProblem::frame(..).check()`.
#[cfg(test)]
mod single_thread_blocking {
    use super::PackedPackedProblem;
    use crate::generic::mmm::generic_f32_4x4;
    use tract_data::internal::TractResult;

    /// Deterministic small-integer ramp: element `i` is
    /// `((i * mul + add) % modulus) as f32 - shift`.
    fn ramp(len: usize, mul: usize, add: usize, modulus: usize, shift: f32) -> Vec<f32> {
        (0..len).map(|i| ((i * mul + add) % modulus) as f32 - shift).collect()
    }

    /// Builds an `m × k` by `k × n` problem from two ramps and checks it.
    fn check_large(m: usize, n: usize, k: usize) -> TractResult<()> {
        let lhs = ramp(m * k, 7, 3, 13, 6.0);
        let rhs = ramp(k * n, 5, 1, 11, 5.0);
        PackedPackedProblem::frame(&*generic_f32_4x4, 0, m, n, lhs, rhs).check()
    }

    // 20×20 panels: more than one block along both axes.
    #[test]
    fn blocked_80x80() -> TractResult<()> {
        check_large(80, 80, 24)
    }

    // 50×10 panels: only the m axis is split.
    #[test]
    fn blocked_skew_200x40() -> TractResult<()> {
        check_large(200, 40, 8)
    }

    // 10×50 panels: only the n axis is split.
    #[test]
    fn blocked_40x200() -> TractResult<()> {
        check_large(40, 200, 8)
    }

    // Exactly 16×16 panels: sits on the block boundary.
    #[test]
    fn blocked_64x64_exact() -> TractResult<()> {
        check_large(64, 64, 16)
    }

    // 17×17 panels: one full block plus a one-panel remainder.
    #[test]
    fn blocked_68x68_offset() -> TractResult<()> {
        check_large(68, 68, 10)
    }
}
/// Scalar in-place RmsNorm over one row.
///
/// Every element is multiplied by `1 / sqrt(mean(x²) + eps)`, where the mean
/// of squares is taken over the whole of `buf` and accumulated left to right
/// in f32. `eps` is added under the square root. An empty slice is left
/// untouched.
pub fn rms_norm_f32(buf: &mut [f32], eps: f32) {
    if !buf.is_empty() {
        // Pass 1: sum of squares, in element order.
        let sum_sq = buf.iter().fold(0.0f32, |acc, &v| acc + v * v);
        let mean_sq = sum_sq / buf.len() as f32;
        let scale = 1.0 / (mean_sq + eps).sqrt();
        // Pass 2: rescale in place.
        buf.iter_mut().for_each(|v| *v *= scale);
    }
}
+ assert!(v.is_finite()); + assert_eq!(v, 0.0); + } + } + + #[test] + fn rms_norm_empty() { + let mut buf: [f32; 0] = []; + rms_norm_f32(&mut buf, 1e-5); + } +} diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs index f64488bd8e..b70a1d77e7 100644 --- a/linalg/src/lib.rs +++ b/linalg/src/lib.rs @@ -107,6 +107,12 @@ pub struct Ops { Box Box> + Send + Sync>, pub softmax2_fastcompact_f32: Box Box> + Send + Sync>, + + /// Fused row-wise RmsNorm: out_i = x_i * rsqrt(mean(x_i²) + eps). + /// Replaces a 4-call composition (MeanOfSquares + Add + Rsqrt + Mul) with + /// a single 2-pass kernel. Called once per row by `core::ops::nn::RmsNorm` + /// when the input is f32 and the axis is the last (contiguous) one. + pub rms_norm_f32: Box, } impl Ops { @@ -243,6 +249,7 @@ pub fn generic() -> Ops { */ softmax2_fastcompact_f16: Box::new(|| generic::reduce::softmax_l2::HSoftMaxL2::red()), softmax2_fastcompact_f32: Box::new(|| generic::reduce::softmax_l2::SSoftMaxL2::red()), + rms_norm_f32: Box::new(generic::rms_norm::rms_norm_f32), }; crate::generic::mmm::plug(&mut ops); ops diff --git a/linalg/src/x86_64_fma.rs b/linalg/src/x86_64_fma.rs index 733d1e7121..e61baa2efe 100644 --- a/linalg/src/x86_64_fma.rs +++ b/linalg/src/x86_64_fma.rs @@ -1,23 +1,37 @@ use crate::Ops; use crate::frame::element_wise::ElementWiseKer; use crate::frame::reduce::{MapReduceKer, ReduceKer}; +use crate::x86_64_fma::softmax::x86_64_avx512_softmax2_fastcompact_f16_64n; use crate::x86_64_fma::softmax::x86_64_fma_softmax2_fastcompact_f32_32n; pub mod mmm; +pub mod act; +pub mod act_f16; +pub mod act_f16_fp16; pub mod by_scalar; +pub mod erf; mod intel; pub mod max; pub mod panel_extract; +pub mod rms_norm; pub mod softmax; const AVX2: fn() -> bool = || is_x86_feature_detected!("avx2"); const FMA: fn() -> bool = || is_x86_feature_detected!("fma"); const AVX512F: fn() -> bool = || is_x86_feature_detected!("avx512f"); +#[cfg(tract_avx512vnni)] +const AVX512VNNI: fn() -> bool = || 
is_x86_feature_detected!("avx512vnni"); tanh_impl!(f32, fma_tanh_f32, 8, 8, is_x86_feature_detected!("fma")); sigmoid_impl!(f32, fma_sigmoid_f32, 8, 8, is_x86_feature_detected!("fma")); +// AVX-512 (zmm, 16-wide) variants. The assembly lives in x86_64/avx512/; the +// main loop handles 64 lanes (4 zmm) per iteration with a 16-lane tail, so +// nr()=16 (any multiple of 16 is safe). +tanh_impl!(f32, avx512_tanh_f32, 16, 16, is_x86_feature_detected!("avx512f")); +sigmoid_impl!(f32, avx512_sigmoid_f32, 16, 16, is_x86_feature_detected!("avx512f")); + fn plug_avx2(_ops: &mut Ops) {} fn plug_fma(ops: &mut Ops) { @@ -33,7 +47,57 @@ fn plug_fma(ops: &mut Ops) { log::info!("sigmoid_f32, tanh_f32: x86_64/fma activated"); } -fn plug_avx512f(_ops: &mut Ops) {} +/// On hosts that also support AVX-512_FP16 (Sapphire Rapids / Granite Rapids / +/// later, and recent Xeon-D / consumer parts), upgrade the f16 element-wise +/// kernels from the f32-roundtrip implementations in `act_f16.rs` to the +/// native f16 implementations in `act_f16_fp16.rs` where the native path is +/// actually faster on this uarch. We benched each op against its f32-roundtrip +/// equivalent on Sapphire Rapids and only plug in the ones that win: +/// +/// hardswish_f16: 8.71 → 31.6 Gelem/s (3.62× native) — plug in +/// leaky_relu_f16: 9.44 → 5.85 Gelem/s (0.62× native — regression) — keep +/// the f32-roundtrip version from act_f16.rs. The native +/// kernel exists in act_f16_fp16.rs for future revisits but +/// is not wired here. 
+fn plug_avx512fp16(ops: &mut Ops) { + ops.hardswish_f16 = Box::new(|| act_f16_fp16::x86_64_avx512fp16_hardswish_f16_128n::ew()); + + log::info!("hardswish_f16: x86_64/avx512fp16 native activated"); +} + +fn plug_avx512f(ops: &mut Ops) { + ops.sigmoid_f32 = Box::new(|| avx512_sigmoid_f32::ew()); + ops.tanh_f32 = Box::new(|| avx512_tanh_f32::ew()); + ops.hardswish_f32 = Box::new(|| act::x86_64_avx512_hardswish_f32_64n::ew()); + ops.leaky_relu_f32 = Box::new(|| act::x86_64_avx512_leaky_relu_f32_64n::ew()); + ops.silu_f32 = Box::new(|| act::x86_64_avx512_silu_f32_16n::ew()); + ops.gelu_f32 = Box::new(|| act::x86_64_avx512_gelu_f32_16n::ew()); + + ops.sigmoid_f16 = Box::new(|| act_f16::x86_64_avx512_sigmoid_f16_16n::ew()); + ops.tanh_f16 = Box::new(|| act_f16::x86_64_avx512_tanh_f16_16n::ew()); + ops.hardswish_f16 = Box::new(|| act_f16::x86_64_avx512_hardswish_f16_64n::ew()); + ops.leaky_relu_f16 = Box::new(|| act_f16::x86_64_avx512_leaky_relu_f16_64n::ew()); + ops.silu_f16 = Box::new(|| act_f16::x86_64_avx512_silu_f16_16n::ew()); + ops.gelu_f16 = Box::new(|| act_f16::x86_64_avx512_gelu_f16_16n::ew()); + + ops.max_f32 = Box::new(|| max::x86_64_avx512_max_f32_64n::red()); + ops.softmax2_fastcompact_f32 = + Box::new(|| softmax::x86_64_avx512_softmax2_fastcompact_f32_64n::red()); + ops.softmax2_fastcompact_f16 = Box::new(|| x86_64_avx512_softmax2_fastcompact_f16_64n::red()); + + ops.erf_f32 = Box::new(|| erf::x86_64_avx512_erf_f32_64n::ew()); + + ops.rms_norm_f32 = Box::new(rms_norm::rms_norm_f32); + + log::info!( + "sigmoid_f32, tanh_f32, hardswish_f32, leaky_relu_f32, \ + silu_f32, gelu_f32, \ + sigmoid_f16, tanh_f16, hardswish_f16, leaky_relu_f16, \ + silu_f16, gelu_f16, \ + max_f32, softmax2_fastcompact_f32, softmax2_fastcompact_f16, erf_f32, \ + rms_norm_f32: x86_64/avx512f activated" + ); +} pub fn plug(ops: &mut Ops) { mmm::plug(ops); @@ -43,6 +107,9 @@ pub fn plug(ops: &mut Ops) { plug_fma(ops); if is_x86_feature_detected!("avx512f") { plug_avx512f(ops); + if 
is_x86_feature_detected!("avx512fp16") { + plug_avx512fp16(ops); + } } } } diff --git a/linalg/src/x86_64_fma/act.rs b/linalg/src/x86_64_fma/act.rs new file mode 100644 index 0000000000..1acf995655 --- /dev/null +++ b/linalg/src/x86_64_fma/act.rs @@ -0,0 +1,258 @@ +// AVX-512 (zmm, 16-wide) element-wise activation kernels with no FMA +// predecessor on x86: hardswish and leaky_relu. They mirror the aarch64 NEON +// kernels (arm64simd_hardswish_f32_8n / arm64simd_leaky_relu_f32_8n) but use +// 512-bit zmm registers, processing 64 f32 lanes per iteration. Validated +// against the generic scalar reference via the *_frame_tests! macros. + +// hardswish(x) = x * relu6(x + 3) / 6 +// = x * max(0, min(6, x + 3)) * (1/6) +ew_impl_wrap!( + f32, + x86_64_avx512_hardswish_f32_64n, + 64, + 16, + (), + #[inline(never)] + fn run(buf: &mut [f32], _: ()) { + debug_assert!(buf.len() % Self::nr() == 0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + if buf.is_empty() { + return; + } + unsafe { x86_64_avx512_hardswish_f32_64n_run(buf) } + } +); + +#[target_feature(enable = "avx512f")] +unsafe fn x86_64_avx512_hardswish_f32_64n_run(buf: &mut [f32]) { + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + vbroadcastss zmm0, xmm0 // 3.0 + vbroadcastss zmm1, xmm1 // 6.0 + vbroadcastss zmm2, xmm2 // 1/6 + vpxord zmm3, zmm3, zmm3 // 0.0 + 2: + vmovaps zmm4, [{ptr}] + vmovaps zmm5, [{ptr} + 64] + vmovaps zmm6, [{ptr} + 128] + vmovaps zmm7, [{ptr} + 192] + + vaddps zmm8, zmm4, zmm0 + vaddps zmm9, zmm5, zmm0 + vaddps zmm10, zmm6, zmm0 + vaddps zmm11, zmm7, zmm0 + + vminps zmm8, zmm8, zmm1 + vminps zmm9, zmm9, zmm1 + vminps zmm10, zmm10, zmm1 + vminps zmm11, zmm11, zmm1 + + vmaxps zmm8, zmm8, zmm3 + vmaxps zmm9, zmm9, zmm3 + vmaxps zmm10, zmm10, zmm3 + vmaxps zmm11, zmm11, zmm3 + + vmulps zmm8, zmm8, zmm4 + vmulps zmm9, zmm9, zmm5 + vmulps zmm10, zmm10, zmm6 + vmulps zmm11, zmm11, zmm7 + + vmulps zmm8, zmm8, zmm2 + vmulps zmm9, zmm9, zmm2 + 
vmulps zmm10, zmm10, zmm2 + vmulps zmm11, zmm11, zmm2 + + vmovaps [{ptr}], zmm8 + vmovaps [{ptr} + 64], zmm9 + vmovaps [{ptr} + 128], zmm10 + vmovaps [{ptr} + 192], zmm11 + + add {ptr}, 256 + sub {len}, 64 + jnz 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("xmm0") 3.0f32 => _, + inout("xmm1") 6.0f32 => _, + inout("xmm2") 1.0f32 / 6.0f32 => _, + out("zmm3") _, + out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, + out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _, + ); + } +} + +#[cfg(test)] +pub mod test_x86_64_avx512_hardswish_f32_64n { + use super::*; + hardswish_frame_tests!( + is_x86_feature_detected!("avx512f"), + f32, + x86_64_avx512_hardswish_f32_64n + ); +} + +// leaky_relu(x) = x > 0 ? x : alpha * x +ew_impl_wrap!( + f32, + x86_64_avx512_leaky_relu_f32_64n, + 64, + 16, + f32, + #[inline(never)] + fn run(buf: &mut [f32], alpha: f32) { + debug_assert!(buf.len() % Self::nr() == 0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + if buf.is_empty() { + return; + } + unsafe { x86_64_avx512_leaky_relu_f32_64n_run(buf, alpha) } + } +); + +#[target_feature(enable = "avx512f")] +unsafe fn x86_64_avx512_leaky_relu_f32_64n_run(buf: &mut [f32], alpha: f32) { + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + vbroadcastss zmm0, xmm0 // alpha + vpxord zmm1, zmm1, zmm1 // 0.0 + 2: + vmovaps zmm4, [{ptr}] + vmovaps zmm5, [{ptr} + 64] + vmovaps zmm6, [{ptr} + 128] + vmovaps zmm7, [{ptr} + 192] + + // alpha * x in zmm8..11 + vmulps zmm8, zmm4, zmm0 + vmulps zmm9, zmm5, zmm0 + vmulps zmm10, zmm6, zmm0 + vmulps zmm11, zmm7, zmm0 + + // mask = x > 0 + vcmpps k1, zmm4, zmm1, 14 + vcmpps k2, zmm5, zmm1, 14 + vcmpps k3, zmm6, zmm1, 14 + vcmpps k4, zmm7, zmm1, 14 + + // where x > 0, overwrite alpha*x with x + vmovaps zmm8{{k1}}, zmm4 + vmovaps zmm9{{k2}}, zmm5 + vmovaps zmm10{{k3}}, zmm6 + vmovaps zmm11{{k4}}, zmm7 + + vmovaps [{ptr}], zmm8 + vmovaps [{ptr} + 64], zmm9 + vmovaps 
[{ptr} + 128], zmm10 + vmovaps [{ptr} + 192], zmm11 + + add {ptr}, 256 + sub {len}, 64 + jnz 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("xmm0") alpha => _, + out("zmm1") _, + out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, + out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _, + out("k1") _, out("k2") _, out("k3") _, out("k4") _, + ); + } +} + +#[cfg(test)] +pub mod test_x86_64_avx512_leaky_relu_f32_64n { + use super::*; + leaky_relu_frame_tests!( + is_x86_feature_detected!("avx512f"), + f32, + x86_64_avx512_leaky_relu_f32_64n + ); +} + +// SiLU(x) = x * sigmoid(x). Composed at the kernel level (mirrors arm64): save +// the input chunk, run the AVX-512 sigmoid kernel in place, then multiply back +// by the saved original. nr() and CHUNK (256) are multiples of 16 so the +// sigmoid kernel always receives a 64-byte-aligned slice whose length is a +// multiple of 16. +ew_impl_wrap!( + f32, + x86_64_avx512_silu_f32_16n, + 16, + 16, + (), + #[inline(never)] + fn run(buf: &mut [f32], _: ()) { + debug_assert!(buf.len() % Self::nr() == 0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + const CHUNK: usize = 256; + let mut scratch = [0f32; CHUNK]; + let mut start = 0; + while start < buf.len() { + let end = (start + CHUNK).min(buf.len()); + let chunk = &mut buf[start..end]; + let n = chunk.len(); + scratch[..n].copy_from_slice(chunk); + super::avx512_sigmoid_f32::run(chunk, ()); + for i in 0..n { + chunk[i] *= scratch[i]; + } + start = end; + } + } +); + +#[cfg(test)] +pub mod test_x86_64_avx512_silu_f32_16n { + use super::*; + silu_frame_tests!(is_x86_feature_detected!("avx512f"), f32, x86_64_avx512_silu_f32_16n); +} + +// Tanh-form GELU (pow=3) matching tract's GeluApproximate: +// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) +// Composed at the kernel level (mirrors arm64): save the original x, compute +// the tanh argument in place, run the AVX-512 tanh kernel, then finish 
with the +// 0.5 * x * (1 + tanh) combine. +ew_impl_wrap!( + f32, + x86_64_avx512_gelu_f32_16n, + 16, + 16, + (), + #[inline(never)] + fn run(buf: &mut [f32], _: ()) { + debug_assert!(buf.len() % Self::nr() == 0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + const SQRT_2_OVER_PI: f32 = 0.7978845608028654; + const COEF: f32 = 0.044715; + const CHUNK: usize = 256; + let mut scratch = [0f32; CHUNK]; + let mut start = 0; + while start < buf.len() { + let end = (start + CHUNK).min(buf.len()); + let chunk = &mut buf[start..end]; + let n = chunk.len(); + for i in 0..n { + let x = chunk[i]; + scratch[i] = x; + chunk[i] = SQRT_2_OVER_PI * (x + COEF * x * x * x); + } + super::avx512_tanh_f32::run(chunk, ()); + for i in 0..n { + chunk[i] = 0.5 * scratch[i] * (1.0 + chunk[i]); + } + start = end; + } + } +); + +#[cfg(test)] +pub mod test_x86_64_avx512_gelu_f32_16n { + use super::*; + gelu_frame_tests!(is_x86_feature_detected!("avx512f"), f32, x86_64_avx512_gelu_f32_16n); +} diff --git a/linalg/src/x86_64_fma/act_f16.rs b/linalg/src/x86_64_fma/act_f16.rs new file mode 100644 index 0000000000..97c623c81b --- /dev/null +++ b/linalg/src/x86_64_fma/act_f16.rs @@ -0,0 +1,300 @@ +// AVX-512 f16 element-wise activations. Each kernel chunks f16 -> 64-byte-aligned +// f32 scratch via vcvtph2ps (cvt_f16_to_f32 below), runs the matching f32 +// AVX-512 kernel (or the avx512_sigmoid_f32 / avx512_tanh_f32 wrappers from +// x86_64_fma.rs), and converts back to f16 via vcvtps2ph (cvt_f32_to_f16). +// Conversion is driven through std::arch intrinsics directly because the +// scalar f16::to_f32 / f16::from_f32 loops are not autovectorized by +// rustc + LLVM (branches / call overhead in the half crate's methods), +// which leaves a naive port stuck around 7 Melem/s. +// +// The f32 AVX-512 activation kernels assume 64-byte aligned input (alignment +// bytes = nr * 4 for nr >= 16). 
The local scratch is wrapped in +// #[repr(C, align(64))] so the contained [f32; 256] sits at a 64-byte boundary. +// +// Validated against the generic f16 reference (HHardSwish8 / HLeakyRelu8 / +// HSigmoid8 / HTanh8 / HSiLU8 / HGelu8) via the existing *_frame_tests! +// macros at SuperApproximate tolerance, which covers the precision delta +// between scalar f16 arithmetic and f32-internal computation. + +use tract_data::internal::f16; + +#[repr(C, align(64))] +struct AlignedScratch([f32; 256]); + +impl AlignedScratch { + fn new() -> Self { + Self([0f32; 256]) + } +} + +const CHUNK: usize = 256; + +// Vectorized f16 <-> f32 helpers using vcvtph2ps / vcvtps2ph. Rustc + LLVM +// do NOT autovectorize the scalar `.to_f32()` loop (the half crate's method +// has branches / function-call overhead), so we drive the conversion with +// intrinsics directly. Both helpers process 16 lanes per iteration; the tail +// (which only fires for the 1-15 leftover lanes inside a CHUNK = 256 batch) +// falls back to scalar. 
+#[target_feature(enable = "avx512f")] +unsafe fn cvt_f16_to_f32(src: &[f16], dst: &mut [f32]) { + use core::arch::x86_64::*; + let n = src.len(); + debug_assert!(dst.len() >= n); + let chunks = n / 16; + unsafe { + for k in 0..chunks { + let m = _mm256_loadu_si256(src.as_ptr().add(k * 16) as *const __m256i); + let z = _mm512_cvtph_ps(m); + _mm512_storeu_ps(dst.as_mut_ptr().add(k * 16), z); + } + for k in (chunks * 16)..n { + *dst.get_unchecked_mut(k) = src.get_unchecked(k).to_f32(); + } + } +} + +#[target_feature(enable = "avx512f")] +unsafe fn cvt_f32_to_f16(src: &[f32], dst: &mut [f16]) { + use core::arch::x86_64::*; + let n = src.len(); + debug_assert!(dst.len() >= n); + let chunks = n / 16; + unsafe { + for k in 0..chunks { + let z = _mm512_loadu_ps(src.as_ptr().add(k * 16)); + // _MM_FROUND_TO_NEAREST_INT == 0 (round-to-nearest-even, matches f16::from_f32) + let m = _mm512_cvtps_ph::<0>(z); + _mm256_storeu_si256(dst.as_mut_ptr().add(k * 16) as *mut __m256i, m); + } + for k in (chunks * 16)..n { + *dst.get_unchecked_mut(k) = f16::from_f32(*src.get_unchecked(k)); + } + } +} + +// hardswish_f16 +ew_impl_wrap!( + f16, + x86_64_avx512_hardswish_f16_64n, + 64, + 32, + (), + #[inline(never)] + fn run(buf: &mut [f16], _: ()) { + debug_assert!(buf.len() % Self::nr() == 0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + if buf.is_empty() { + return; + } + let mut scratch = AlignedScratch::new(); + let s = &mut scratch.0; + let mut i = 0; + while i < buf.len() { + let n = (CHUNK).min(buf.len() - i); + unsafe { cvt_f16_to_f32(&buf[i..i + n], &mut s[..n]) }; + super::act::x86_64_avx512_hardswish_f32_64n::run(&mut s[..n], ()); + unsafe { cvt_f32_to_f16(&s[..n], &mut buf[i..i + n]) }; + i += n; + } + } +); + +#[cfg(test)] +pub mod test_x86_64_avx512_hardswish_f16_64n { + use super::*; + hardswish_frame_tests!( + is_x86_feature_detected!("avx512f"), + f16, + x86_64_avx512_hardswish_f16_64n + ); +} + +// leaky_relu_f16 (parameter: alpha as f16) 
+ew_impl_wrap!( + f16, + x86_64_avx512_leaky_relu_f16_64n, + 64, + 32, + f16, + #[inline(never)] + fn run(buf: &mut [f16], alpha: f16) { + debug_assert!(buf.len() % Self::nr() == 0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + if buf.is_empty() { + return; + } + let alpha_f32 = alpha.to_f32(); + let mut scratch = AlignedScratch::new(); + let s = &mut scratch.0; + let mut i = 0; + while i < buf.len() { + let n = (CHUNK).min(buf.len() - i); + unsafe { cvt_f16_to_f32(&buf[i..i + n], &mut s[..n]) }; + super::act::x86_64_avx512_leaky_relu_f32_64n::run(&mut s[..n], alpha_f32); + unsafe { cvt_f32_to_f16(&s[..n], &mut buf[i..i + n]) }; + i += n; + } + } +); + +#[cfg(test)] +pub mod test_x86_64_avx512_leaky_relu_f16_64n { + use super::*; + leaky_relu_frame_tests!( + is_x86_feature_detected!("avx512f"), + f16, + x86_64_avx512_leaky_relu_f16_64n + ); +} + +// sigmoid_f16 (calls the avx512_sigmoid_f32 wrapper from x86_64_fma.rs; +// its nr() is 16 so CHUNK=256 is always a clean multiple) +ew_impl_wrap!( + f16, + x86_64_avx512_sigmoid_f16_16n, + 16, + 16, + (), + #[inline(never)] + fn run(buf: &mut [f16], _: ()) { + debug_assert!(buf.len() % Self::nr() == 0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + if buf.is_empty() { + return; + } + let mut scratch = AlignedScratch::new(); + let s = &mut scratch.0; + let mut i = 0; + while i < buf.len() { + let n = (CHUNK).min(buf.len() - i); + unsafe { cvt_f16_to_f32(&buf[i..i + n], &mut s[..n]) }; + super::avx512_sigmoid_f32::run(&mut s[..n], ()); + unsafe { cvt_f32_to_f16(&s[..n], &mut buf[i..i + n]) }; + i += n; + } + } +); + +#[cfg(test)] +pub mod test_x86_64_avx512_sigmoid_f16_16n { + use super::*; + sigmoid_frame_tests!(is_x86_feature_detected!("avx512f"), f16, x86_64_avx512_sigmoid_f16_16n); +} + +// tanh_f16 +ew_impl_wrap!( + f16, + x86_64_avx512_tanh_f16_16n, + 16, + 16, + (), + #[inline(never)] + fn run(buf: &mut [f16], _: ()) { + debug_assert!(buf.len() % Self::nr() == 
0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + if buf.is_empty() { + return; + } + let mut scratch = AlignedScratch::new(); + let s = &mut scratch.0; + let mut i = 0; + while i < buf.len() { + let n = (CHUNK).min(buf.len() - i); + unsafe { cvt_f16_to_f32(&buf[i..i + n], &mut s[..n]) }; + super::avx512_tanh_f32::run(&mut s[..n], ()); + unsafe { cvt_f32_to_f16(&s[..n], &mut buf[i..i + n]) }; + i += n; + } + } +); + +#[cfg(test)] +pub mod test_x86_64_avx512_tanh_f16_16n { + use super::*; + tanh_frame_tests!(is_x86_feature_detected!("avx512f"), f16, x86_64_avx512_tanh_f16_16n); +} + +// silu_f16: x * sigmoid(x). Mirror the f32 silu pattern: save the input +// (in f32), run sigmoid in place on the scratch, then multiply back. +ew_impl_wrap!( + f16, + x86_64_avx512_silu_f16_16n, + 16, + 16, + (), + #[inline(never)] + fn run(buf: &mut [f16], _: ()) { + debug_assert!(buf.len() % Self::nr() == 0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + if buf.is_empty() { + return; + } + let mut work = AlignedScratch::new(); + let mut save = AlignedScratch::new(); + let w = &mut work.0; + let v = &mut save.0; + let mut i = 0; + while i < buf.len() { + let n = (CHUNK).min(buf.len() - i); + unsafe { cvt_f16_to_f32(&buf[i..i + n], &mut w[..n]) }; + v[..n].copy_from_slice(&w[..n]); + super::avx512_sigmoid_f32::run(&mut w[..n], ()); + for j in 0..n { + w[j] *= v[j]; + } + unsafe { cvt_f32_to_f16(&w[..n], &mut buf[i..i + n]) }; + i += n; + } + } +); + +#[cfg(test)] +pub mod test_x86_64_avx512_silu_f16_16n { + use super::*; + silu_frame_tests!(is_x86_feature_detected!("avx512f"), f16, x86_64_avx512_silu_f16_16n); +} + +// Tanh-form GELU (matches tract's GeluApproximate, pow=3, see act.rs gelu_f32): +// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) +ew_impl_wrap!( + f16, + x86_64_avx512_gelu_f16_16n, + 16, + 16, + (), + #[inline(never)] + fn run(buf: &mut [f16], _: ()) { + debug_assert!(buf.len() % Self::nr() == 0); 
+ debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + if buf.is_empty() { + return; + } + const SQRT_2_OVER_PI: f32 = 0.7978845608028654; + const COEF: f32 = 0.044715; + let mut work = AlignedScratch::new(); + let mut save = AlignedScratch::new(); + let w = &mut work.0; + let v = &mut save.0; + let mut i = 0; + while i < buf.len() { + let n = (CHUNK).min(buf.len() - i); + unsafe { cvt_f16_to_f32(&buf[i..i + n], &mut v[..n]) }; + for j in 0..n { + let x = v[j]; + w[j] = SQRT_2_OVER_PI * (x + COEF * x * x * x); + } + super::avx512_tanh_f32::run(&mut w[..n], ()); + for j in 0..n { + w[j] = 0.5 * v[j] * (1.0 + w[j]); + } + unsafe { cvt_f32_to_f16(&w[..n], &mut buf[i..i + n]) }; + i += n; + } + } +); + +#[cfg(test)] +pub mod test_x86_64_avx512_gelu_f16_16n { + use super::*; + gelu_frame_tests!(is_x86_feature_detected!("avx512f"), f16, x86_64_avx512_gelu_f16_16n); +} diff --git a/linalg/src/x86_64_fma/act_f16_fp16.rs b/linalg/src/x86_64_fma/act_f16_fp16.rs new file mode 100644 index 0000000000..4a40002458 --- /dev/null +++ b/linalg/src/x86_64_fma/act_f16_fp16.rs @@ -0,0 +1,203 @@ +// AVX-512_FP16 native f16 element-wise activations. +// +// Sapphire Rapids (and later Intel) added the AVX-512 FP16 ISA: zmm-wide +// arithmetic on f16 directly (`vmulph`, `vfmadd*ph`, `vmaxph`, `vminph`, +// `vaddph`, `vsubph`, etc.). 32 f16 lanes per zmm — double the parallelism of +// the f32-roundtrip kernels in `act_f16.rs`, and zero conversion at the IO +// boundary. +// +// The kernels here mirror the algorithm of the f32 versions in `act.rs` and +// the f32-roundtrip f16 versions in `act_f16.rs`. Polynomials are evaluated +// directly in f16, accepting the lower mantissa precision (11 bits vs f32's +// 24) — the resulting tolerance fits inside the f16 activation tests' +// SuperApproximate band. +// +// Gated on `is_x86_feature_detected!("avx512fp16")` (the actual gating happens +// in `plug_avx512fp16` over in `x86_64_fma.rs`). 
Pre-FP16 AVX-512 hosts +// (Skylake-X, Cascade Lake, Ice Lake server prior to fp16 extension) keep +// using `act_f16.rs`'s f32-roundtrip versions. + +use tract_data::internal::f16; + +const FP16_TARGETS: &str = "avx512f,avx512fp16,avx512bw"; + +// hardswish(x) = x * clamp(x + 3, 0, 6) * (1/6). +// 128 f16 per iter (4 zmm × 32 lanes), 256 bytes / iter — same memory throughput +// as the f32 kernel's 64 f32 / iter. +ew_impl_wrap!( + f16, + x86_64_avx512fp16_hardswish_f16_128n, + 128, + 32, + (), + #[inline(never)] + fn run(buf: &mut [f16], _: ()) { + debug_assert!(buf.len() % Self::nr() == 0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + if buf.is_empty() { + return; + } + unsafe { hardswish_f16_run(buf) } + } +); + +#[target_feature(enable = "avx512f,avx512fp16,avx512bw")] +unsafe fn hardswish_f16_run(buf: &mut [f16]) { + let len = buf.len(); + let ptr = buf.as_ptr() as *mut u8; + let three = f16::from_f32(3.0).to_bits(); + let six = f16::from_f32(6.0).to_bits(); + let recip6 = f16::from_f32(1.0 / 6.0).to_bits(); + unsafe { + std::arch::asm!(" + vpbroadcastw zmm0, eax // 3.0 + vpbroadcastw zmm1, ecx // 6.0 + vpbroadcastw zmm2, edx // 1/6 + vpxord zmm3, zmm3, zmm3 // 0.0 + 2: + vmovdqa64 zmm4, [{ptr}] + vmovdqa64 zmm5, [{ptr} + 64] + vmovdqa64 zmm6, [{ptr} + 128] + vmovdqa64 zmm7, [{ptr} + 192] + + vaddph zmm8, zmm4, zmm0 + vaddph zmm9, zmm5, zmm0 + vaddph zmm10, zmm6, zmm0 + vaddph zmm11, zmm7, zmm0 + + vminph zmm8, zmm8, zmm1 + vminph zmm9, zmm9, zmm1 + vminph zmm10, zmm10, zmm1 + vminph zmm11, zmm11, zmm1 + + vmaxph zmm8, zmm8, zmm3 + vmaxph zmm9, zmm9, zmm3 + vmaxph zmm10, zmm10, zmm3 + vmaxph zmm11, zmm11, zmm3 + + vmulph zmm8, zmm8, zmm4 + vmulph zmm9, zmm9, zmm5 + vmulph zmm10, zmm10, zmm6 + vmulph zmm11, zmm11, zmm7 + + vmulph zmm8, zmm8, zmm2 + vmulph zmm9, zmm9, zmm2 + vmulph zmm10, zmm10, zmm2 + vmulph zmm11, zmm11, zmm2 + + vmovdqa64 [{ptr}], zmm8 + vmovdqa64 [{ptr} + 64], zmm9 + vmovdqa64 [{ptr} + 128], zmm10 + vmovdqa64 
[{ptr} + 192], zmm11 + + add {ptr}, 256 + sub {len}, 128 + jnz 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("eax") three as u32, + in("ecx") six as u32, + in("edx") recip6 as u32, + out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _, + out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, + out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _, + ); + } +} + +// leaky_relu(x, alpha) = x if x >= 0 else alpha*x +// For 0 <= alpha <= 1: leaky_relu(x, alpha) = max(x, alpha*x). For the typical +// alpha values used (0.01, 0.1, 0.2) this is exact. +// +// NOTE: This native fp16 version benched ~38% SLOWER than the f32-roundtrip +// version on Sapphire Rapids (9.44 Gelem/s f32-roundtrip vs 5.85 Gelem/s +// native, n=1024, single-thread). The two compute ops per element (vmulph + +// vmaxph) appear not to saturate Sapphire Rapids' FP16 execution port the +// same way f32 mul/max saturate the FP32 ports. The kernel is correct (passes +// proptest against the f16 reference) but is NOT plugged in — see the +// `plug_avx512fp16` comment in `x86_64_fma.rs`. Kept here in case a different +// AVX-512_FP16 uarch (Granite Rapids etc.) flips the comparison. 
+ew_impl_wrap!( + f16, + x86_64_avx512fp16_leaky_relu_f16_128n, + 128, + 32, + f16, + #[inline(never)] + fn run(buf: &mut [f16], alpha: f16) { + debug_assert!(buf.len() % Self::nr() == 0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + if buf.is_empty() { + return; + } + unsafe { leaky_relu_f16_run(buf, alpha) } + } +); + +#[target_feature(enable = "avx512f,avx512fp16,avx512bw")] +unsafe fn leaky_relu_f16_run(buf: &mut [f16], alpha: f16) { + let len = buf.len(); + let ptr = buf.as_ptr() as *mut u8; + let alpha_bits = alpha.to_bits(); + unsafe { + std::arch::asm!(" + vpbroadcastw zmm0, eax // alpha + 2: + vmovdqa64 zmm4, [{ptr}] + vmovdqa64 zmm5, [{ptr} + 64] + vmovdqa64 zmm6, [{ptr} + 128] + vmovdqa64 zmm7, [{ptr} + 192] + + vmulph zmm8, zmm4, zmm0 + vmulph zmm9, zmm5, zmm0 + vmulph zmm10, zmm6, zmm0 + vmulph zmm11, zmm7, zmm0 + + vmaxph zmm8, zmm8, zmm4 + vmaxph zmm9, zmm9, zmm5 + vmaxph zmm10, zmm10, zmm6 + vmaxph zmm11, zmm11, zmm7 + + vmovdqa64 [{ptr}], zmm8 + vmovdqa64 [{ptr} + 64], zmm9 + vmovdqa64 [{ptr} + 128], zmm10 + vmovdqa64 [{ptr} + 192], zmm11 + + add {ptr}, 256 + sub {len}, 128 + jnz 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("eax") alpha_bits as u32, + out("zmm0") _, + out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, + out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _, + ); + } +} + +#[cfg(test)] +pub mod test_x86_64_avx512fp16_hardswish { + use super::*; + crate::hardswish_frame_tests!( + is_x86_feature_detected!("avx512fp16"), + f16, + x86_64_avx512fp16_hardswish_f16_128n + ); +} + +#[cfg(test)] +pub mod test_x86_64_avx512fp16_leaky_relu { + use super::*; + crate::leaky_relu_frame_tests!( + is_x86_feature_detected!("avx512fp16"), + f16, + x86_64_avx512fp16_leaky_relu_f16_128n + ); +} + +// Suppress unused-const lint until we expand to more kernels. 
+#[allow(dead_code)] +const _UNUSED: &str = FP16_TARGETS; diff --git a/linalg/src/x86_64_fma/erf.rs b/linalg/src/x86_64_fma/erf.rs new file mode 100644 index 0000000000..c34b207c7c --- /dev/null +++ b/linalg/src/x86_64_fma/erf.rs @@ -0,0 +1,204 @@ +// AVX-512 (zmm, 16-wide) error function kernel. Mirrors generic/erf.rs::serf +// (Abramowitz & Stegun 7.1.26 six-coefficient polynomial) but runs the +// polynomial via FMA chains over 4 zmm registers per iteration (64 lanes per +// loop step). Validated against the generic scalar reference via +// erf_frame_tests! at SuperApproximate tolerance. +// +// Algorithm (per lane): +// signum = sign(x); abs = |x| +// y = a6 +// y = y*abs + a5 (Horner FMA) +// y = y*abs + a4 +// y = y*abs + a3 +// y = y*abs + a2 +// y = y*abs + a1 +// y = y * abs (final factor of abs) +// y = y + 1 +// y = y^16 (4 sequential squares) +// y = 1 / y (vdivps, full IEEE precision) +// y = 1 - y +// result = copysign(y, x) + +ew_impl_wrap!( + f32, + x86_64_avx512_erf_f32_64n, + 64, + 16, + (), + #[inline(never)] + fn run(buf: &mut [f32], _: ()) { + debug_assert!(buf.len() % Self::nr() == 0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + if buf.is_empty() { + return; + } + unsafe { x86_64_avx512_erf_f32_64n_run(buf) } + } +); + +#[target_feature(enable = "avx512f")] +unsafe fn x86_64_avx512_erf_f32_64n_run(buf: &mut [f32]) { + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + const A1: f32 = 0.0705230784; + const A2: f32 = 0.0422820123; + const A3: f32 = 0.0092705272; + const A4: f32 = 0.0001520143; + const A5: f32 = 0.0002765672; + const A6: f32 = 0.0000430638; + // 0x7fffffff: positive-finite mask (clears sign bit). As f32 bits, this + // is NaN; we never use it as a numeric value — only as a bit mask via vandps. 
+ const ABS_MASK: f32 = f32::from_bits(0x7fffffff); + const SIGN_MASK: f32 = f32::from_bits(0x80000000); + std::arch::asm!(" + // broadcast constants (xmmN -> zmmN, broadcast across all 16 lanes) + vbroadcastss zmm0, xmm0 // a1 + vbroadcastss zmm1, xmm1 // a2 + vbroadcastss zmm2, xmm2 // a3 + vbroadcastss zmm3, xmm3 // a4 + vbroadcastss zmm4, xmm4 // a5 + vbroadcastss zmm5, xmm5 // a6 + vbroadcastss zmm6, xmm6 // 1.0 + vbroadcastss zmm7, xmm7 // abs mask (0x7fffffff) + vbroadcastss zmm8, xmm8 // sign mask (0x80000000) + 2: + // load 4 zmm of input + vmovaps zmm9, [{ptr}] + vmovaps zmm10, [{ptr} + 64] + vmovaps zmm11, [{ptr} + 128] + vmovaps zmm12, [{ptr} + 192] + + // sign[i] = x[i] & SIGN_MASK (keeps only the sign bit) + vandps zmm13, zmm9, zmm8 + vandps zmm14, zmm10, zmm8 + vandps zmm15, zmm11, zmm8 + vandps zmm16, zmm12, zmm8 + + // abs[i] = x[i] & ABS_MASK (clears the sign bit) + vandps zmm9, zmm9, zmm7 + vandps zmm10, zmm10, zmm7 + vandps zmm11, zmm11, zmm7 + vandps zmm12, zmm12, zmm7 + + // y = a6 (in zmm17..20, 4 independent channels) + vmovaps zmm17, zmm5 + vmovaps zmm18, zmm5 + vmovaps zmm19, zmm5 + vmovaps zmm20, zmm5 + + // y = y*abs + a5 + vfmadd213ps zmm17, zmm9, zmm4 + vfmadd213ps zmm18, zmm10, zmm4 + vfmadd213ps zmm19, zmm11, zmm4 + vfmadd213ps zmm20, zmm12, zmm4 + + // y = y*abs + a4 + vfmadd213ps zmm17, zmm9, zmm3 + vfmadd213ps zmm18, zmm10, zmm3 + vfmadd213ps zmm19, zmm11, zmm3 + vfmadd213ps zmm20, zmm12, zmm3 + + // y = y*abs + a3 + vfmadd213ps zmm17, zmm9, zmm2 + vfmadd213ps zmm18, zmm10, zmm2 + vfmadd213ps zmm19, zmm11, zmm2 + vfmadd213ps zmm20, zmm12, zmm2 + + // y = y*abs + a2 + vfmadd213ps zmm17, zmm9, zmm1 + vfmadd213ps zmm18, zmm10, zmm1 + vfmadd213ps zmm19, zmm11, zmm1 + vfmadd213ps zmm20, zmm12, zmm1 + + // y = y*abs + a1 + vfmadd213ps zmm17, zmm9, zmm0 + vfmadd213ps zmm18, zmm10, zmm0 + vfmadd213ps zmm19, zmm11, zmm0 + vfmadd213ps zmm20, zmm12, zmm0 + + // y = y * abs (final factor) + vmulps zmm17, zmm17, zmm9 + vmulps zmm18, zmm18, 
zmm10 + vmulps zmm19, zmm19, zmm11 + vmulps zmm20, zmm20, zmm12 + + // y = y + 1 + vaddps zmm17, zmm17, zmm6 + vaddps zmm18, zmm18, zmm6 + vaddps zmm19, zmm19, zmm6 + vaddps zmm20, zmm20, zmm6 + + // y^16: square 4 times + vmulps zmm17, zmm17, zmm17 + vmulps zmm18, zmm18, zmm18 + vmulps zmm19, zmm19, zmm19 + vmulps zmm20, zmm20, zmm20 + + vmulps zmm17, zmm17, zmm17 + vmulps zmm18, zmm18, zmm18 + vmulps zmm19, zmm19, zmm19 + vmulps zmm20, zmm20, zmm20 + + vmulps zmm17, zmm17, zmm17 + vmulps zmm18, zmm18, zmm18 + vmulps zmm19, zmm19, zmm19 + vmulps zmm20, zmm20, zmm20 + + vmulps zmm17, zmm17, zmm17 + vmulps zmm18, zmm18, zmm18 + vmulps zmm19, zmm19, zmm19 + vmulps zmm20, zmm20, zmm20 + + // y = 1 / y (full-precision reciprocal, matches generic .recip()) + vdivps zmm21, zmm6, zmm17 + vdivps zmm22, zmm6, zmm18 + vdivps zmm23, zmm6, zmm19 + vdivps zmm24, zmm6, zmm20 + + // y = 1 - y + vsubps zmm21, zmm6, zmm21 + vsubps zmm22, zmm6, zmm22 + vsubps zmm23, zmm6, zmm23 + vsubps zmm24, zmm6, zmm24 + + // copysign: stamp the original sign bit onto the (positive) result + vorps zmm21, zmm21, zmm13 + vorps zmm22, zmm22, zmm14 + vorps zmm23, zmm23, zmm15 + vorps zmm24, zmm24, zmm16 + + // store + vmovaps [{ptr}], zmm21 + vmovaps [{ptr} + 64], zmm22 + vmovaps [{ptr} + 128], zmm23 + vmovaps [{ptr} + 192], zmm24 + + add {ptr}, 256 + sub {len}, 64 + jnz 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("xmm0") A1 => _, + inout("xmm1") A2 => _, + inout("xmm2") A3 => _, + inout("xmm3") A4 => _, + inout("xmm4") A5 => _, + inout("xmm5") A6 => _, + inout("xmm6") 1f32 => _, + inout("xmm7") ABS_MASK => _, + inout("xmm8") SIGN_MASK => _, + out("zmm9") _, out("zmm10") _, out("zmm11") _, out("zmm12") _, + out("zmm13") _, out("zmm14") _, out("zmm15") _, out("zmm16") _, + out("zmm17") _, out("zmm18") _, out("zmm19") _, out("zmm20") _, + out("zmm21") _, out("zmm22") _, out("zmm23") _, out("zmm24") _, + ); + } +} + +#[cfg(test)] +pub mod test_x86_64_avx512_erf_f32_64n { + 
use super::*; + crate::erf_frame_tests!(is_x86_feature_detected!("avx512f"), f32, x86_64_avx512_erf_f32_64n); +} diff --git a/linalg/src/x86_64_fma/max.rs b/linalg/src/x86_64_fma/max.rs index cea5710477..6bbc99077e 100644 --- a/linalg/src/x86_64_fma/max.rs +++ b/linalg/src/x86_64_fma/max.rs @@ -65,3 +65,72 @@ mod test_x86_64_fma_max_f32_32n { use super::*; crate::max_frame_tests!(is_x86_feature_detected!("avx2"), f32, x86_64_fma_max_f32_32n); } + +// AVX-512 version: processes 64 f32 per loop iteration (4 zmm registers of 16 +// lanes each). Runtime-gated on avx512f (see x86_64_fma.rs::plug_avx512f); on +// non-AVX512 CPUs this kernel is never registered and the FMA path above stays +// in use. nr=64, 64-byte (16xf32) alignment. +reduce_impl_wrap!( + f32, + x86_64_avx512_max_f32_64n, + 64, + 16, + (), + f32::MIN, + #[inline(never)] + fn run(buf: &[f32], _: ()) -> f32 { + assert!(buf.len() % 64 == 0); + assert!(buf.len() > 0); + unsafe { x86_64_avx512_max_f32_64n_run(buf) } + }, + #[inline(never)] + fn reduce_two(a: f32, b: f32) -> f32 { + a.max(b) + } +); + +#[target_feature(enable = "avx512f")] +unsafe fn x86_64_avx512_max_f32_64n_run(buf: &[f32]) -> f32 { + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + let mut acc = f32::MIN; + std::arch::asm!(" + vbroadcastss zmm0, xmm0 + vmovaps zmm1, zmm0 + vmovaps zmm2, zmm0 + vmovaps zmm3, zmm0 + 2: + vmaxps zmm0, zmm0, [{ptr}] + vmaxps zmm1, zmm1, [{ptr} + 64] + vmaxps zmm2, zmm2, [{ptr} + 128] + vmaxps zmm3, zmm3, [{ptr} + 192] + add {ptr}, 256 + sub {len}, 64 + jnz 2b + vmaxps zmm0, zmm0, zmm1 + vmaxps zmm2, zmm2, zmm3 + vmaxps zmm0, zmm0, zmm2 // zmm0 holds 16 partial maxima + vextractf64x4 ymm1, zmm0, 1 // upper 256 bits (8xf32) of zmm0 -> ymm1 (avx512f) + vmaxps ymm0, ymm0, ymm1 // ymm0 holds 8 values + vextractf128 xmm1, ymm0, 1 // upper 4xf32 -> xmm1 + vmaxps xmm0, xmm0, xmm1 // xmm0 holds 4 values + vpermilps xmm1, xmm0, 2 + (3 << 2) // second 2x32 bit half moved to top + vmaxps xmm0, xmm0, xmm1 // 
xmm0 holds 2 values + vpermilps xmm1, xmm0, 1 // second f32 to top + vmaxps xmm0, xmm0, xmm1 + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("zmm0") acc, + out("zmm1") _, out("zmm2") _, out("zmm3") _, + ); + acc + } +} + +#[cfg(test)] +mod test_x86_64_avx512_max_f32_64n { + use super::*; + crate::max_frame_tests!(is_x86_feature_detected!("avx512f"), f32, x86_64_avx512_max_f32_64n); +} diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index 7b094ac2e9..eaaf40b9d0 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -2,7 +2,7 @@ use crate::Ops; use crate::block_quant::*; use crate::mmm::ImplementationQuality::ManuallyOptimized; use crate::mmm::MatMatMul; -use crate::pack::PackedFormat; +use crate::pack::{PackedFormat, PackedI8K4}; use super::*; @@ -91,6 +91,22 @@ MMMExternKernel! { avx2_mmm_i32_8x8(8,8)@(256,4) where(AVX2) store(i8) } +// AVX-512 VNNI int8 GEMM: same 8x8 column-accumulator tile and quantization +// epilogue as avx2_mmm_i32_8x8, but the i8i8 matmul inner loop uses VPDPBUSD +// (4-way K dot) over the K=4-inner PackedI8K4 layout. VPDPBUSD is u8*s8, so the +// kernel offsets A by +128 and removes the 128*sum_k(B) bias per column before +// the epilogue, making the i32 accumulators bit-identical to the AVX2 path. +// +// Gated on `tract_avx512vnni` (set by build.rs when the assembler can encode +// `vpdpbusd ymm`; binutils < 2.30 cannot). On old toolchains the kernel is +// omitted entirely and the AVX2 i32 path is used instead. +#[cfg(tract_avx512vnni)] +MMMExternKernel! 
{ avx512vnni_mmm_i32_8x8(8,8)@(256,4) where(AVX512VNNI) + packing[1] = i8i8 => |k| k.with_packing(PackedI8K4::new(8), PackedI8K4::new(8)); + quality(ManuallyOptimized) + store(i8) +} + pub fn plug(ops: &mut Ops) { if is_x86_feature_detected!("avx2") { plug_avx2(ops); @@ -98,11 +114,22 @@ pub fn plug(ops: &mut Ops) { plug_fma(ops); if is_x86_feature_detected!("avx512f") { plug_avx512f(ops); + #[cfg(tract_avx512vnni)] + if is_x86_feature_detected!("avx512vnni") { + plug_avx512vnni(ops); + } } } } } +#[cfg(tract_avx512vnni)] +pub fn plug_avx512vnni(ops: &mut Ops) { + ops.mmm_impls.push(avx512vnni_mmm_i32_8x8.mmm()); + ops.qmmm_i32 = Box::new(|_, _, _| avx512vnni_mmm_i32_8x8.mmm()); + log::info!("qmmm_i32: x86_64/avx512vnni activated"); +} + pub fn plug_avx2(ops: &mut Ops) { ops.mmm_impls.push(mmm::avx2_mmm_i32_8x8.mmm()); ops.qmmm_i32 = Box::new(|_, _, _| mmm::avx2_mmm_i32_8x8.mmm()); diff --git a/linalg/src/x86_64_fma/rms_norm.rs b/linalg/src/x86_64_fma/rms_norm.rs new file mode 100644 index 0000000000..468aae101b --- /dev/null +++ b/linalg/src/x86_64_fma/rms_norm.rs @@ -0,0 +1,192 @@ +// Fused AVX-512 RmsNorm (single contiguous row): +// out[i] = x[i] * rsqrt(mean(x[i]²) + eps) +// +// Two passes over the row: +// Pass 1 (sum of squares): acc += x² over 4 zmm accumulators, then reduce +// to a scalar f32; scalar tail handles the +// (len % 64) remainder. +// Pass 2 (multiply-back): broadcast inv_std into zmm0, multiply each +// 4-zmm chunk in place; scalar tail. +// +// Uses unaligned loads/stores (vmovups) — the caller hands per-row slices of a +// possibly-misaligned tensor, so we can't assume 64-byte alignment. On Cascade +// Lake the unaligned penalty is negligible for cache-resident rows. 
+ +#[target_feature(enable = "avx512f")] +unsafe fn rms_norm_f32_inner(buf: &mut [f32], eps: f32) { + let n = buf.len(); + let chunks = n / 64; + let tail_start = chunks * 64; + let ptr = buf.as_mut_ptr(); + + // --- Pass 1: sum of squares --- + let mut sum_sq: f32 = 0.0; + if chunks > 0 { + let p = ptr; + let c = chunks; + unsafe { + std::arch::asm!(" + vpxord zmm0, zmm0, zmm0 + vpxord zmm1, zmm1, zmm1 + vpxord zmm2, zmm2, zmm2 + vpxord zmm3, zmm3, zmm3 + 2: + vmovups zmm4, [{p}] + vmovups zmm5, [{p} + 64] + vmovups zmm6, [{p} + 128] + vmovups zmm7, [{p} + 192] + vfmadd231ps zmm0, zmm4, zmm4 + vfmadd231ps zmm1, zmm5, zmm5 + vfmadd231ps zmm2, zmm6, zmm6 + vfmadd231ps zmm3, zmm7, zmm7 + add {p}, 256 + sub {c}, 1 + jnz 2b + + vaddps zmm0, zmm0, zmm1 + vaddps zmm2, zmm2, zmm3 + vaddps zmm0, zmm0, zmm2 + vextractf64x4 ymm1, zmm0, 1 + vaddps ymm0, ymm0, ymm1 + vextractf128 xmm1, ymm0, 1 + vaddps xmm0, xmm0, xmm1 + vpermilps xmm1, xmm0, 2 + (3 << 2) + vaddps xmm0, xmm0, xmm1 + vpermilps xmm1, xmm0, 1 + vaddps xmm0, xmm0, xmm1 + ", + p = inout(reg) p => _, + c = inout(reg) c => _, + out("xmm0") sum_sq, + out("zmm1") _, out("zmm2") _, out("zmm3") _, + out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, + ); + } + } + // scalar tail + for i in tail_start..n { + let x = unsafe { *buf.get_unchecked(i) }; + sum_sq += x * x; + } + + // --- Compute inv_std (scalar) --- + let mean_sq = sum_sq / (n as f32); + let inv_std = (mean_sq + eps).sqrt().recip(); + + // --- Pass 2: multiply by inv_std --- + if chunks > 0 { + let p = ptr; + let c = chunks; + let inv = inv_std; + unsafe { + std::arch::asm!(" + vbroadcastss zmm0, xmm0 + 2: + vmovups zmm1, [{p}] + vmovups zmm2, [{p} + 64] + vmovups zmm3, [{p} + 128] + vmovups zmm4, [{p} + 192] + vmulps zmm1, zmm1, zmm0 + vmulps zmm2, zmm2, zmm0 + vmulps zmm3, zmm3, zmm0 + vmulps zmm4, zmm4, zmm0 + vmovups [{p}], zmm1 + vmovups [{p} + 64], zmm2 + vmovups [{p} + 128], zmm3 + vmovups [{p} + 192], zmm4 + add {p}, 256 + sub {c}, 1 + jnz 2b 
+ ", + p = inout(reg) p => _, + c = inout(reg) c => _, + inout("xmm0") inv => _, + out("zmm1") _, out("zmm2") _, out("zmm3") _, out("zmm4") _, + ); + } + } + for i in tail_start..n { + unsafe { + *buf.get_unchecked_mut(i) *= inv_std; + } + } +} + +pub fn rms_norm_f32(buf: &mut [f32], eps: f32) { + if buf.is_empty() { + return; + } + unsafe { rms_norm_f32_inner(buf, eps) } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn ref_rms_norm(buf: &mut [f32], eps: f32) { + let n = buf.len() as f32; + let sum_sq: f32 = buf.iter().map(|x| x * x).sum(); + let mean_sq = sum_sq / n; + let inv_std = (mean_sq + eps).sqrt().recip(); + for x in buf.iter_mut() { + *x *= inv_std; + } + } + + fn close_enough(got: &[f32], want: &[f32], tol: f32) { + assert_eq!(got.len(), want.len()); + for (i, (g, w)) in got.iter().zip(want.iter()).enumerate() { + let diff = (g - w).abs(); + assert!(diff <= tol, "lane {i}: got {g}, want {w}, diff {diff}"); + } + } + + #[test] + fn matches_reference_64() { + if !std::is_x86_feature_detected!("avx512f") { + return; + } + let mut x: Vec = (0..64).map(|i| (i as f32 * 0.13).sin() * 5.0).collect(); + let mut y = x.clone(); + rms_norm_f32(&mut x, 1e-5); + ref_rms_norm(&mut y, 1e-5); + close_enough(&x, &y, 1e-5); + } + + #[test] + fn matches_reference_1024_with_tail() { + if !std::is_x86_feature_detected!("avx512f") { + return; + } + // 1024 + 17 = a row that exercises the scalar tail loop. + let n = 1024 + 17; + let mut x: Vec = (0..n).map(|i| (i as f32 * 0.07).cos() * 3.0).collect(); + let mut y = x.clone(); + rms_norm_f32(&mut x, 1e-5); + ref_rms_norm(&mut y, 1e-5); + close_enough(&x, &y, 1e-4); + } + + #[test] + fn matches_reference_short_below_chunk() { + if !std::is_x86_feature_detected!("avx512f") { + return; + } + // Shorter than 64 -> all scalar tail. 
+ let mut x: Vec = vec![0.5, -1.5, 2.5, -3.5, 0.0, 4.0, -4.0, 1.0]; + let mut y = x.clone(); + rms_norm_f32(&mut x, 1e-5); + ref_rms_norm(&mut y, 1e-5); + close_enough(&x, &y, 1e-5); + } + + #[test] + fn empty_is_noop() { + if !std::is_x86_feature_detected!("avx512f") { + return; + } + let mut x: Vec = vec![]; + rms_norm_f32(&mut x, 1e-5); + assert!(x.is_empty()); + } +} diff --git a/linalg/src/x86_64_fma/softmax.rs b/linalg/src/x86_64_fma/softmax.rs index ed63d3ca4e..283dbb55d0 100644 --- a/linalg/src/x86_64_fma/softmax.rs +++ b/linalg/src/x86_64_fma/softmax.rs @@ -1,3 +1,6 @@ +use crate::num_traits::Zero; +use tract_data::internal::f16; + map_reduce_impl_wrap!( f32, x86_64_fma_softmax2_fastcompact_f32_32n, @@ -119,3 +122,277 @@ mod test_x86_64_fma_softmax2_fastcompact_f32_32n { x86_64_fma_softmax2_fastcompact_f32_32n ); } + +// AVX-512 version: processes 64 f32 per loop iteration (4 zmm registers of 16 +// lanes each). Same fast-compact-exp algorithm as the FMA kernel above: +// y = bitcast_u32(max(0, SLOPE*(x-max) + OFFSET)) (via vcvttps2dq) +// then writes y back and accumulates sum(y). Runtime-gated on avx512f (see +// x86_64_fma.rs::plug_avx512f); non-AVX512 CPUs keep using the FMA kernel. +// nr=64, 64-byte (16xf32) alignment. 
+map_reduce_impl_wrap!( + f32, + x86_64_avx512_softmax2_fastcompact_f32_64n, + 64, + 16, + f32, + f32::MIN, + 0f32, + #[inline(never)] + fn run(buf: &mut [f32], max: f32) -> f32 { + assert!(buf.len() % 64 == 0); + assert!(buf.len() > 0); + unsafe { x86_64_avx512_softmax2_fastcompact_f32_64n_run(buf, max) } + }, + #[inline(never)] + fn reduce_two(a: f32, b: f32) -> f32 { + a + b + } +); + +#[target_feature(enable = "avx512f")] +unsafe fn x86_64_avx512_softmax2_fastcompact_f32_64n_run(buf: &mut [f32], max: f32) -> f32 { + unsafe { + let len = buf.len(); + let ptr = buf.as_mut_ptr(); // mutable: the asm below stores results back through this pointer + let mut acc = 0f32; + const MLN2: f32 = 0.6931471805f32; + const A: f32 = 8388608.0f32; + const B: f32 = 1065353216.0f32; + const C: f32 = 60801.0f32; + const SLOPE: f32 = A / MLN2; + const OFFSET: f32 = B - C; + std::arch::asm!(" + vbroadcastss zmm0, xmm0 + vmovaps zmm1, zmm0 + vmovaps zmm2, zmm0 + vmovaps zmm3, zmm0 + + vpxord zmm28, zmm28, zmm28 // zero (clamp floor) + vbroadcastss zmm29, xmm29 // max + vbroadcastss zmm30, xmm30 // slope + vbroadcastss zmm31, xmm31 // offset + 2: + vmovaps zmm4, [{ptr}] + vmovaps zmm5, [{ptr} + 64] + vmovaps zmm6, [{ptr} + 128] + vmovaps zmm7, [{ptr} + 192] + + vsubps zmm4, zmm4, zmm29 + vsubps zmm5, zmm5, zmm29 + vsubps zmm6, zmm6, zmm29 + vsubps zmm7, zmm7, zmm29 + + vmovaps zmm8, zmm31 + vmovaps zmm9, zmm31 + vmovaps zmm10, zmm31 + vmovaps zmm11, zmm31 + + vfmadd231ps zmm8, zmm4, zmm30 + vfmadd231ps zmm9, zmm5, zmm30 + vfmadd231ps zmm10, zmm6, zmm30 + vfmadd231ps zmm11, zmm7, zmm30 + + vmaxps zmm8, zmm8, zmm28 + vmaxps zmm9, zmm9, zmm28 + vmaxps zmm10, zmm10, zmm28 + vmaxps zmm11, zmm11, zmm28 + + vcvttps2dq zmm8, zmm8 + vcvttps2dq zmm9, zmm9 + vcvttps2dq zmm10, zmm10 + vcvttps2dq zmm11, zmm11 + + vmovaps [{ptr}] , zmm8 + vmovaps [{ptr} + 64] , zmm9 + vmovaps [{ptr} + 128], zmm10 + vmovaps [{ptr} + 192], zmm11 + + vaddps zmm0, zmm0, zmm8 + vaddps zmm1, zmm1, zmm9 + vaddps zmm2, zmm2, zmm10 + vaddps zmm3, zmm3, zmm11 + + add {ptr}, 256 + sub {len}, 64 
+ jnz 2b + + vaddps zmm0, zmm0, zmm1 + vaddps zmm2, zmm2, zmm3 + vaddps zmm0, zmm0, zmm2 // zmm0 holds 16 partial sums + vextractf64x4 ymm1, zmm0, 1 // upper 256 bits (8xf32) -> ymm1 (avx512f) + vaddps ymm0, ymm0, ymm1 // ymm0 holds 8 values + vextractf128 xmm1, ymm0, 1 // upper 4xf32 -> xmm1 + vaddps xmm0, xmm0, xmm1 // xmm0 holds 4 values + vpermilps xmm1, xmm0, 2 + (3 << 2) + vaddps xmm0, xmm0, xmm1 // xmm0 holds 2 values + vpermilps xmm1, xmm0, 1 + vaddps xmm0, xmm0, xmm1 + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("zmm0") acc, + out("zmm1") _, out("zmm2") _, out("zmm3") _, + out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, + out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _, + out("zmm28") _, + inout("zmm29") max => _, + inout("zmm30") SLOPE => _, + inout("zmm31") OFFSET => _, + ); + acc + } +} + +#[cfg(test)] +mod test_x86_64_avx512_softmax2_fastcompact_f32_64n { + use super::*; + crate::softmax_l2_frame_tests!( + is_x86_feature_detected!("avx512f"), + f32, + x86_64_avx512_softmax2_fastcompact_f32_64n + ); +} + +// AVX-512 f16 softmax_l2: same fast-compact-exp algorithm as the FMA f32 +// kernel, with f16 <-> f32 conversion at the IO boundary. Each loop iteration +// handles 64 f16 (128 bytes) through 4× (ymm f16 load -> vcvtph2ps -> zmm f32 +// compute -> vcvttps2dq -> vcvtps2ph -> ymm f16 store). The sum is accumulated +// in f32 across the loop (higher precision than the generic HSoftMaxL2 which +// accumulates in f16) and cast to f16 at return; the SuperApproximate test +// tolerance covers the precision delta. +// nr=64 (multiple of 4 ymm f16 loads); alignment_items=32 (64-byte aligned). 
+map_reduce_impl_wrap!( + f16, + x86_64_avx512_softmax2_fastcompact_f16_64n, + 64, + 32, + f16, + f16::MIN, + f16::zero(), + #[inline(never)] + fn run(buf: &mut [f16], max: f16) -> f16 { + assert!(buf.len() % 64 == 0); + assert!(buf.len() > 0); + unsafe { x86_64_avx512_softmax2_fastcompact_f16_64n_run(buf, max) } + }, + #[inline(never)] + fn reduce_two(a: f16, b: f16) -> f16 { + a + b + } +); + +#[target_feature(enable = "avx512f")] +unsafe fn x86_64_avx512_softmax2_fastcompact_f16_64n_run( + buf: &mut [tract_data::internal::f16], + max: tract_data::internal::f16, +) -> tract_data::internal::f16 { + unsafe { + let len = buf.len(); + let ptr = buf.as_mut_ptr(); // mutable: the asm below stores results back through this pointer + let max_f32: f32 = max.to_f32(); + let mut acc = 0f32; + const MLN2: f32 = 0.6931471805f32; + const A: f32 = 8388608.0f32; + const B: f32 = 1065353216.0f32; + const C: f32 = 60801.0f32; + const SLOPE: f32 = A / MLN2; + const OFFSET: f32 = B - C; + std::arch::asm!(" + vbroadcastss zmm0, xmm0 + vmovaps zmm1, zmm0 + vmovaps zmm2, zmm0 + vmovaps zmm3, zmm0 + + vpxord zmm28, zmm28, zmm28 // 0 (clamp floor) + vbroadcastss zmm29, xmm29 // max (f32) + vbroadcastss zmm30, xmm30 // slope + vbroadcastss zmm31, xmm31 // offset + 2: + // load 4 ymm of f16 (16 f16 per ymm = 32 bytes), convert to zmm f32 + vcvtph2ps zmm4, [{ptr}] + vcvtph2ps zmm5, [{ptr} + 32] + vcvtph2ps zmm6, [{ptr} + 64] + vcvtph2ps zmm7, [{ptr} + 96] + + // subtract max + vsubps zmm4, zmm4, zmm29 + vsubps zmm5, zmm5, zmm29 + vsubps zmm6, zmm6, zmm29 + vsubps zmm7, zmm7, zmm29 + + // OFFSET + SLOPE * (x - max) + vmovaps zmm8, zmm31 + vmovaps zmm9, zmm31 + vmovaps zmm10, zmm31 + vmovaps zmm11, zmm31 + vfmadd231ps zmm8, zmm4, zmm30 + vfmadd231ps zmm9, zmm5, zmm30 + vfmadd231ps zmm10, zmm6, zmm30 + vfmadd231ps zmm11, zmm7, zmm30 + + // max(0, ...) 
+ vmaxps zmm8, zmm8, zmm28 + vmaxps zmm9, zmm9, zmm28 + vmaxps zmm10, zmm10, zmm28 + vmaxps zmm11, zmm11, zmm28 + + // fast-compact-exp trick: the truncated i32 has the same bit + // pattern as the f32 ~exp(x), so accumulate AS f32 + store as f16 + vcvttps2dq zmm8, zmm8 + vcvttps2dq zmm9, zmm9 + vcvttps2dq zmm10, zmm10 + vcvttps2dq zmm11, zmm11 + + vaddps zmm0, zmm0, zmm8 + vaddps zmm1, zmm1, zmm9 + vaddps zmm2, zmm2, zmm10 + vaddps zmm3, zmm3, zmm11 + + // convert back to f16 and store (imm8 operand 0 = round to nearest even) + vcvtps2ph [{ptr}], zmm8, 0 + vcvtps2ph [{ptr} + 32], zmm9, 0 + vcvtps2ph [{ptr} + 64], zmm10, 0 + vcvtps2ph [{ptr} + 96], zmm11, 0 + + add {ptr}, 128 + sub {len}, 64 + jnz 2b + + // reduce zmm0..3 to a scalar f32 in xmm0 + vaddps zmm0, zmm0, zmm1 + vaddps zmm2, zmm2, zmm3 + vaddps zmm0, zmm0, zmm2 + vextractf64x4 ymm1, zmm0, 1 + vaddps ymm0, ymm0, ymm1 + vextractf128 xmm1, ymm0, 1 + vaddps xmm0, xmm0, xmm1 + vpermilps xmm1, xmm0, 2 + (3 << 2) + vaddps xmm0, xmm0, xmm1 + vpermilps xmm1, xmm0, 1 + vaddps xmm0, xmm0, xmm1 + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("zmm0") acc, + out("zmm1") _, out("zmm2") _, out("zmm3") _, + out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, + out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _, + out("zmm28") _, + inout("zmm29") max_f32 => _, + inout("zmm30") SLOPE => _, + inout("zmm31") OFFSET => _, + ); + f16::from_f32(acc) + } +} + +#[cfg(test)] +mod test_x86_64_avx512_softmax2_fastcompact_f16_64n { + use super::*; + use tract_data::internal::f16; + crate::softmax_l2_frame_tests!( + is_x86_feature_detected!("avx512f"), + f16, + x86_64_avx512_softmax2_fastcompact_f16_64n + ); +} diff --git a/linalg/x86_64/avx512/sigmoid_f32.S.j2 b/linalg/x86_64/avx512/sigmoid_f32.S.j2 index 9b0c3c9df1..bc57bbb8c5 100644 --- a/linalg/x86_64/avx512/sigmoid_f32.S.j2 +++ b/linalg/x86_64/avx512/sigmoid_f32.S.j2 @@ -1,8 +1,9 @@ {# // vim: set syntax=asm : - -// TODO[TSolberg] : Not 
validated. +// AVX-512 (zmm, 16-wide) sigmoid. Uses a rational (Padé-style) approximation +// of sigmoid clamped to [-18, 18]. Validated against the generic scalar +// reference via sigmoid_frame_tests! (see x86_64_fma.rs). System V ABI: args: rdi, rsi, rdx, rcx, r8, r9 @@ -89,7 +90,7 @@ avx512_sigmoid_f32_{{suffix}} proc cmp rsi, 0 je {{L}}done - cmp rsi, 32 + cmp rsi, 64 jl {{L}}loop_1 {{L}}loop_4: @@ -194,8 +195,8 @@ avx512_sigmoid_f32_{{suffix}} proc vmovaps [rdi + 192], zmm7 add rdi, 256 - sub rsi, 32 - cmp rsi, 32 + sub rsi, 64 + cmp rsi, 64 jg {{L}}loop_4 cmp rsi, 0 @@ -243,8 +244,8 @@ avx512_sigmoid_f32_{{suffix}} proc vaddps zmm4, zmm4, zmm1 vmovaps [rdi], zmm4 - add rdi, 32 - sub rsi, 8 + add rdi, 64 + sub rsi, 16 jnz {{L}}loop_1 {{L}}done: diff --git a/linalg/x86_64/avx512/tanh_f32.S.j2 b/linalg/x86_64/avx512/tanh_f32.S.j2 index 65e01bd029..dd0325b201 100644 --- a/linalg/x86_64/avx512/tanh_f32.S.j2 +++ b/linalg/x86_64/avx512/tanh_f32.S.j2 @@ -1,7 +1,9 @@ {# // vim: set syntax=asm : -// TODO[TSolberg] : Not validated. +// AVX-512 (zmm, 16-wide) tanh. Polynomial-numerator / polynomial-denominator +// rational approximation of tanh clamped to [-9, 9]. Validated against the +// generic scalar reference via tanh_frame_tests! (see x86_64_fma.rs). 
System V ABI: args: rdi, rsi, rdx, rcx, r8, r9 @@ -88,7 +90,7 @@ avx512_tanh_f32_{{suffix}} proc cmp rsi, 0 je {{L}}done - cmp rsi, 32 + cmp rsi, 64 jl {{L}}loop_1 {{L}}loop_4: @@ -188,8 +190,8 @@ avx512_tanh_f32_{{suffix}} proc vmovaps [rdi + 192], zmm7 add rdi, 256 - sub rsi, 32 - cmp rsi, 32 + sub rsi, 64 + cmp rsi, 64 jg {{L}}loop_4 cmp rsi, 0 @@ -235,8 +237,8 @@ avx512_tanh_f32_{{suffix}} proc vdivps zmm4, zmm4, zmm12 vmovaps [rdi], zmm4 - add rdi, 32 - sub rsi, 8 + add rdi, 64 + sub rsi, 16 jnz {{L}}loop_1 {{L}}done: diff --git a/linalg/x86_64/avx512vnni/dummy_vnni.S b/linalg/x86_64/avx512vnni/dummy_vnni.S new file mode 100644 index 0000000000..293f07fd4c --- /dev/null +++ b/linalg/x86_64/avx512vnni/dummy_vnni.S @@ -0,0 +1,13 @@ +// Build-time capability probe for the assembler, used by build.rs +// (assembler_supports_avx512vnni). Older binutils — notably the Debian stretch +// x86_64 toolchain in CI — predate AVX-512 VNNI (added in binutils ~2.30) and +// cannot assemble `vpdpbusd ymm` even when targeting a VNNI-capable CPU. If +// this file fails to assemble, build.rs skips the VNNI kernel and the +// `tract_avx512vnni` cfg, and the runtime falls back to the AVX2 i32 path. +// Not linked into anything. 
+.intel_syntax noprefix +.text +.globl tract_avx512vnni_probe +tract_avx512vnni_probe: + vpdpbusd ymm0, ymm1, ymm2 + ret diff --git a/linalg/x86_64/fma/avx512vnni_mmm_i32_8x8.S.j2 b/linalg/x86_64/fma/avx512vnni_mmm_i32_8x8.S.j2 new file mode 100644 index 0000000000..ed5d4540af --- /dev/null +++ b/linalg/x86_64/fma/avx512vnni_mmm_i32_8x8.S.j2 @@ -0,0 +1,676 @@ +{# +// vim: set syntax=asm : + +/* mmm 8x8: + + ymm0 ymm1 ymm2 ymm3 ymm4 ymm5 ymm6 ymm7 + +System V ABI: + args: rdi, rsi, rdx, rcx, r8, r9 + preserve: rbx, rsp, rbp, r12, r13, r14, r15 + scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11 + return: rax (+rdx) + +Windows ABI: + args: RCX, RDX, R8, R9 + preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15 + scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of YMM0-15 and ZMM0-15 + return: rax (+rdx) +*/ +#} + +{% if msvc %} + +_text segment +avx512vnni_mmm_i32_8x8_{{suffix}} proc + +{% else %} + +.intel_syntax noprefix +.text +.p2align 5 +.globl {{G}}avx512vnni_mmm_i32_8x8_{{suffix}} +{{G}}avx512vnni_mmm_i32_8x8_{{suffix}}: +.cfi_startproc + +{% endif %} + + push rbp + mov rbp, rsp + +{% if family == "windows" %} +// https://www.agner.org/optimize/calling_conventions.pdf xmm6-15 are not scratch +// https://stackoverflow.com/questions/43358429/save-value-of-xmm-registers + and rsp,-16 + lea rsp,[rsp-160] + vmovaps [rsp], xmm6 + vmovaps [rsp+16*1],xmm7 + vmovaps [rsp+16*2],xmm8 + vmovaps [rsp+16*3],xmm9 + vmovaps [rsp+16*4],xmm10 + vmovaps [rsp+16*5],xmm11 + vmovaps [rsp+16*6],xmm12 + vmovaps [rsp+16*7],xmm13 + vmovaps [rsp+16*8],xmm14 + vmovaps [rsp+16*9],xmm15 + + push rdi + push rsi + + mov rdi, rcx + +{% endif %} + + push rbx + push r12 + push r13 + push r14 + push r15 + + sub rsp, 8 + +{% if family == "unix" %} +.cfi_def_cfa_offset 64 +{% endif %} + + stmxcsr [rsp + 4] +{% if msvc %} + mov rax, 1FC0h +{% else %} + mov rax, 0x1FC0 +{% endif %} + mov [rsp], eax + ldmxcsr [rsp] + +{% include "dispatcher.j2" %} + +{{L}}clear: 
+ vzeroall + jmp {{L}}non_linear_loop + +{{L}}add_mat_mul: + mov r12, [rdi + 32] // packing + mov rbx, [rdi + 24] // B + mov rax, [rdi + 16] // A + + mov rcx, [rdi + 8] // k + test rcx, rcx + jz {{L}}non_linear_loop + + cmp r12, 1 + je {{L}}main_loop_packed_packed_i8i8 + +{{L}}main_loop_packed_packed: + vmovaps ymm12, [rax] + + {% for i in range(0, 8) %} + vbroadcastss ymm14, dword ptr [rbx + {{i}} * 4] + vpmulld ymm13, ymm12, ymm14 + vpaddd ymm{{i}}, ymm{{i}}, ymm13 + {% endfor %} + + add rax, 32 + add rbx, 32 + dec rcx + jnz {{L}}main_loop_packed_packed + + jmp {{L}}non_linear_loop + +{{L}}main_loop_packed_packed_i8i8: + // PackedI8K4 layout: per K=4 block, the A panel is 8 rows x 4 K-bytes (32 + // bytes, lane m = A[m, 4kb..4kb+3]) and the B panel is 8 cols x 4 K-bytes + // (lane n = B[n, 4kb..4kb+3]). VPDPBUSD is u8 x s8, so A is offset by +128 + // (-> u8) and the resulting 128*sum_k(B[n]) bias is removed per column after + // the loop, leaving the i32 accumulators identical to the AVX2 path. + + add rcx, 3 + shr rcx, 2 // rcx <- ceil(k/4) K=4 blocks + + mov r8d, 0x01010101 + movd xmm11, r8d + vpbroadcastd ymm11, xmm11 // ymm11 <- u8 ones (sum of B) + + mov r8d, 0x80808080 + movd xmm12, r8d + vpbroadcastd ymm12, xmm12 // ymm12 <- byte 0x80 (A + 128) + + vpxor ymm10, ymm10, ymm10 // ymm10 <- per-col sum_k B[n] + +{{L}}loop_4k_i8i8: + vmovdqu ymm8, [rax] // A block: lane m = A[m,4kb..] + vpaddb ymm8, ymm8, ymm12 // s8 -> u8 (+128, modular) + + vmovdqu ymm9, [rbx] // B block: lane n = B[n,4kb..] 
+ vpdpbusd ymm10, ymm11, ymm9 // sum_k B[n] += sum_t B[n,4kb+t] + + {% for n in range(0, 8) %} + vpbroadcastd ymm13, dword ptr [rbx + {{n}} * 4] + vpdpbusd ymm{{n}}, ymm8, ymm13 // acc[n][m] += sum_t (A[m]+128)*B[n] + {% endfor %} + + add rax, 32 + add rbx, 32 + dec rcx + jnz {{L}}loop_4k_i8i8 + + // remove the +128 bias added on A: acc[n] -= 128 * sum_k B[n] + vpslld ymm10, ymm10, 7 // lane n <- 128 * sum_k B[n] + {% for n in range(0, 8) %} + mov r8d, {{n}} + movd xmm14, r8d + vpbroadcastd ymm14, xmm14 // index = n in every lane + vpermd ymm15, ymm14, ymm10 // splat 128*sum_k B[n] + vpsubd ymm{{n}}, ymm{{n}}, ymm15 + {% endfor %} + + jmp {{L}}non_linear_loop + +{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_scalars.j2" %} +{% set mr = 8 %}{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_per_rows.j2" %} +{% set mr = 8 %}{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_per_cols.j2" %} +{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_load_tile.j2" %} + +{{L}}add_unicast: + + mov r10, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rbx, [rdi + 24] // col stride + mov r8, [rdi + 32] // item size + + cmp r8, 4 + je {{L}}non_linear_addc_i32 + +{# +// This is not great as vgatherdps reads 32-bits values and goes beyond our buffer. Probably harmless though. +// Commented and replaced with the "mov al" loop beyond to pacify valgrind. +// ymm14 and ymm15 are the same as in the non_linear_addc_i32 case (compute them before the test right above here. 
+// {% for i in range(0, 8) %} +// vpcmpeqd ymm15, ymm15, ymm15 +// vgatherdps ymm12, [ r10 + ymm14 ], ymm15 // 0xxx 1xxx 2xxx 3xxx 4xxx 5xxx 6xxx 7xxx +// +// // we need to go through vpmovsxbd, shuffling naively erases signs +// vpshufb ymm12, ymm12, ymm10 // 0123 0123 0123 0123 4567 4567 4567 4567 +// +// vpermd ymm12, ymm11, ymm12 // 0123 4567 +// vpmovsxbd ymm12, xmm12 // sign extend +// +// vpaddd ymm{{i}}, ymm{{i}}, ymm12 +// add r10, rbx +// {% endfor %} +#} + + {% for col in range(0, 8) %} + mov r8, r10 + {% for half in range(0, 2) %} + {% for lane in range(0, 4) %} + mov al, [ r8 ] + add r8, rsi + movsx eax, al + pinsrd xmm10, eax, {{lane}} + {% endfor %} + vperm2f128 ymm10, ymm10, ymm10, 1 + {% endfor %} + vpaddd ymm{{col}}, ymm{{col}}, ymm10 + add r10, rbx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}non_linear_addc_i32: + + mov eax, 0 +{% for i in range(0, 4) %} + pinsrd xmm14, eax, {{i}} + add eax, esi +{% endfor %} + vpermq ymm14, ymm14, 78 // 0b01001110 +{% for i in range(0, 4) %} + pinsrd xmm14, eax, {{i}} + add eax, esi +{% endfor %} + vpermq ymm14, ymm14, 78 // 0b01001110 + + +{% if msvc %} + vpbroadcastd ymm10, dword ptr [ offset byte_shuffle ] + vmovups ymm11, dword ptr [ offset i128_shuffle ] +{% else %} + vpbroadcastd ymm10, [ rip + {{L}}byte_shuffle ] + vmovups ymm11, [ rip + {{L}}i128_shuffle ] +{% endif %} + +{% for i in range(0, 8) %} + vpcmpeqd ymm15, ymm15, ymm15 + vgatherdps ymm12, [ r10 + ymm14 ], ymm15 + vpaddd ymm{{i}}, ymm{{i}}, ymm12 + add r10, rbx +{% endfor %} + + jmp {{L}}non_linear_loop + +{% if msvc %} +.data +byte_shuffle dd 201851904 // 0x0c080400 +i128_shuffle dd 0, 4 +.code +{% else %} +{{L}}byte_shuffle: .int 201851904 // 0x0c080400 +{{L}}i128_shuffle: .int 0, 4 +{% endif %} + +{{L}}add_row_col_products: + mov rax, [ rdi + 8 ] + mov rbx, [ rdi + 16 ] + + vmovups ymm12, [rax] + +{% for i in range(0, 8) %} + vbroadcastss ymm14, dword ptr [rbx + {{ i * 4 }} ] + vpmulld ymm15, ymm12, ymm14 + vpaddd ymm{{i}}, ymm{{i}}, 
ymm15 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale: + mov r8, [ rdi + 16 ] // policy + vbroadcastss ymm8, dword ptr [rdi + 24] // multi + + mov rax, 1 + movq xmm9, rax + vpbroadcastq ymm9, xmm9 // ymm9 <- 1 + + mov rax, [ rdi + 8 ] // xmm10 <- shift + 31 + add rax, 31 + movq xmm10, rax + vpbroadcastq ymm10, xmm10 + + mov rax, 1 + movq xmm11, rax + vpsubq ymm12, ymm10, ymm9 // shift+31 - 1 + vpsllq ymm11, ymm9, xmm12 // ymm11 <- 1 << (shift + 31 - 1) + + cmp r8, 1 + je {{L}}q_scale_rounding_zero + cmp r8, 2 + je {{L}}q_scale_rounding_away + cmp r8, 3 + je {{L}}q_scale_rounding_minus_inf + cmp r8, 4 + je {{L}}q_scale_rounding_plus_inf + cmp r8, 5 + je {{L}}q_scale_rounding_even + cmp r8, 6 + je {{L}}q_scale_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_scale_rounding_zero: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsubq ymm14, ymm14, ymm9 + vpsubq ymm15, ymm15, ymm9 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_away: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp 
{{L}}non_linear_loop + +{{L}}q_scale_rounding_minus_inf: // signum * ( (abs << 32 + 1<<30+shift) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + // sign extract for nudging in the right direction + vpxor ymm13, ymm13, ymm13 + vpcmpgtd ymm13, ymm{{i}}, ymm13 // ymm13 <- s0, s1, ..s8 (signums, as all ones or all zeros) + vpsrld ymm13, ymm13, 31 // then just 0 or 1 + + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + // reinterpret ymm13=s0i32..s7 as i64 and blend with zero to pick the even ones as i64 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm14, ymm14, ymm12 + + vpsrldq ymm13, ymm13, 4 // ymm13 <- s1, s2, .., s7, 0 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm15, ymm15, ymm12 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_plus_inf: // signum * ( (abs << 32 + 1<<30+shift) >> shift ) + + vpbroadcastd ymm9, xmm9 + +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpxor ymm13, ymm13, ymm13 + + // sign extract for nudging in the right direction + vpcmpgtd ymm13, ymm{{i}}, ymm13 // ymm13 <- s0, s1, ..s8 (signums, as all ones or all zeros) + vpaddd ymm13, ymm13, ymm9 // if val >= 0 { 0i32 } else { 1i32 } + + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + // reinterpret ymm13=s0i32..s7 as i64 and blend with zero to pick the even ones as i64 + vpxor ymm12, ymm12, ymm12 + 
vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm14, ymm14, ymm12 + + vpsrldq ymm13, ymm13, 4 // ymm13 <- s1, s2, .., s7, 0 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm15, ymm15, ymm12 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_even: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpsrlq ymm12, ymm14, xmm10 + vpand ymm12, ymm12, ymm9 + vpaddq ymm14, ymm14, ymm12 + vpsubq ymm14, ymm14, ymm9 + + vpsrlq ymm12, ymm15, xmm10 + vpand ymm12, ymm12, ymm9 + vpaddq ymm15, ymm15, ymm12 + vpsubq ymm15, ymm15, ymm9 + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_odd: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpsrlq ymm12, ymm14, xmm10 + vpand ymm12, ymm12, ymm9 + vpsubq ymm14, ymm14, ymm12 + + vpsrlq ymm12, ymm15, xmm10 + vpand ymm12, ymm12, ymm9 + vpsubq ymm15, ymm15, ymm12 + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp 
{{L}}non_linear_loop + +{{L}}q_shl: + mov eax, [ rdi + 8 ] // xmm10 <- -shift (8 times) + movd xmm10, eax + vpbroadcastd ymm10, xmm10 + +{% for i in range(0, 8) %} + vpsllvd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr: + mov r8, [ rdi + 16 ] // policy + + mov eax, 1 + movd xmm9, eax + vpbroadcastd ymm9, xmm9 // ymm9 <- 1u32 (8 times) + + mov eax, [ rdi + 8 ] // xmm10 <- shift (8 times) + movd xmm10, eax + vpbroadcastd ymm10, xmm10 + + mov ebx, 1 + mov cl, al + sub cl, 1 // rcx <- shift -1 + sal ebx, cl // rbx <- (1 << (shift - 1)) + movd xmm11, ebx + vpbroadcastd ymm11, xmm11 // ymm11 <- "half" + + vpxor ymm12, ymm12, ymm12 // ymm12 <- zeroes + + cmp r8, 1 + je {{L}}q_shr_rounding_zero + cmp r8, 2 + je {{L}}q_shr_rounding_away + cmp r8, 3 + je {{L}}q_shr_rounding_minus_inf + cmp r8, 4 + je {{L}}q_shr_rounding_plus_inf + cmp r8, 5 + je {{L}}q_shr_rounding_even + cmp r8, 6 + je {{L}}q_shr_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_shr_rounding_zero: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsubd ymm14, ymm14, ymm9 + vpaddd ymm14, ymm14, ymm11 + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_away: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpaddd ymm14, ymm14, ymm11 + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_minus_inf: +{% for i in range(0, 8) %} + vpsubd ymm{{i}}, ymm{{i}}, ymm9 + vpaddd ymm{{i}}, ymm{{i}}, ymm11 + vpsravd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_plus_inf: +{% for i in range(0, 8) %} + vpaddd ymm{{i}}, ymm{{i}}, ymm11 + vpsravd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_even: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsravd ymm13, ymm14, ymm10 + vpand ymm13, ymm13, ymm9 + vpsubd ymm13, ymm13, ymm9 // nudge = 
((abs >>l shift) & 0x01) - 1 + vpaddd ymm14, ymm14, ymm13 // add nudge + vpaddd ymm14, ymm14, ymm11 // add half + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_odd: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsravd ymm13, ymm14, ymm10 + vpand ymm13, ymm13, ymm9 + vpsubd ymm13, ymm12, ymm13 // nudge = - ((abs >>l shift) & 0x01) + vpaddd ymm14, ymm14, ymm13 // add nudge + vpaddd ymm14, ymm14, ymm11 // add half + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store: + mov r8, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rdx, [rdi + 24] // col stride + mov rcx, [rdi + 32] // item size + + cmp rcx, 4 + je {{L}}store_strides_i32 + + {% for col in range(0, 8) %} + mov r10, r8 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov byte ptr [r10], bl + add r10, rsi + {% endfor %} + vperm2f128 ymm{{col}}, ymm{{col}}, ymm{{col}}, 1 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov byte ptr [r10], bl + add r10, rsi + {% endfor %} + add r8, rdx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}store_strides_i32: + {% for col in range(0, 8) %} + mov r10, r8 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov dword ptr [r10], ebx + add r10, rsi + {% endfor %} + vperm2f128 ymm{{col}}, ymm{{col}}, ymm{{col}}, 1 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov dword ptr [r10], ebx + add r10, rsi + {% endfor %} + add r8, rdx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}return: + ldmxcsr [rsp + 4] + add rsp, 8 + + pop r15 + pop r14 + pop r13 + pop r12 + pop rbx + +{% if family == "windows" %} + pop rsi + pop rdi + + vmovaps xmm15, [rsp+16*9] + vmovaps xmm14, [rsp+16*8] + vmovaps xmm13, [rsp+16*7] + vmovaps xmm12, [rsp+16*6] + vmovaps xmm11, [rsp+16*5] + vmovaps xmm10, [rsp+16*4] + vmovaps xmm9, [rsp+16*3] + 
vmovaps xmm8, [rsp+16*2] + vmovaps xmm7, [rsp+16*1] + vmovaps xmm6, [rsp] +{% endif %} + + mov rsp, rbp + pop rbp + ret + + +{{L}}one_32bit: +{% if msvc %} + dd 1 +{% else %} + .int 1 +{% endif %} + +{% if msvc %} +avx512vnni_mmm_i32_8x8_{{suffix}} endp +_text ends +end +{% else %} +.cfi_endproc +{% endif %} diff --git a/metal/Cargo.toml b/metal/Cargo.toml index 17ffd1b341..43103d21b2 100644 --- a/metal/Cargo.toml +++ b/metal/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-metal" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = [ "Hubert de La Jonquière ", diff --git a/nnef/Cargo.toml b/nnef/Cargo.toml index 2051d393b5..982057aa38 100644 --- a/nnef/Cargo.toml +++ b/nnef/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-nnef" -version = "0.23.0-pre" +version = "0.23.1-pre" authors = ["Mathieu Poumeyrol "] license = "MIT OR Apache-2.0" description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/nnef/nnef-resources/Cargo.toml b/nnef/nnef-resources/Cargo.toml index 03a6838d85..8220ba44e0 100644 --- a/nnef/nnef-resources/Cargo.toml +++ b/nnef/nnef-resources/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-nnef-resources" -version = "0.23.0-pre" +version = "0.23.1-pre" authors = [ "Mathieu Poumeyrol ", "Hubert de La Jonquière " diff --git a/nnef/src/transform.rs b/nnef/src/transform.rs index 414f0a2a5a..186b0d3530 100644 --- a/nnef/src/transform.rs +++ b/nnef/src/transform.rs @@ -87,7 +87,6 @@ impl ModelTransform for PatchTransform { // Build the TypedModelPatch let mut patch = TypedModelPatch { model: patch_model, taps, ..TypedModelPatch::default() }; - let mut inputs_to_remove = vec![]; let mut new_output_names = vec![]; for (i, lhs_name) in lhs_names.iter().enumerate() { @@ -124,15 +123,19 @@ impl ModelTransform for PatchTransform { &[wire], )?[0]; } + // Shunt consumers of the input outlet to the new wire. 
Like + // `TypedModelPatch::shunt_outside` we do NOT touch + // `model.inputs` — the original input stays in the list as + // a disconnected Source. Use `select_inputs(...)` after the + // patch to drop it explicitly. patch.shunt_outside(model, input_outlet, wire)?; - inputs_to_remove.push(input_outlet); } else if self.0.new_outputs.contains(lhs_name) { new_output_names.push(lhs_name.clone()); } else { let is_intermediate = i < lhs_names.len() - 1; if !is_intermediate { bail!( - "Wire '{}' is not a model input and not declared in new_outputs", + "Wire '{}' is not a model input/intermediate and not declared in new_outputs", lhs_name ); } @@ -141,9 +144,6 @@ impl ModelTransform for PatchTransform { patch.apply(model)?; - for inp in &inputs_to_remove { - model.inputs.retain(|o| o != inp); - } for name in &new_output_names { let node_id = model.node_id_by_name(name)?; model.outputs.push(OutletId::new(node_id, 0)); diff --git a/onnx-opl/Cargo.toml b/onnx-opl/Cargo.toml index ca43b47230..de71ad7f0e 100644 --- a/onnx-opl/Cargo.toml +++ b/onnx-opl/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-onnx-opl" -version = "0.23.0-pre" +version = "0.23.1-pre" authors = ["Mathieu Poumeyrol "] license = "MIT OR Apache-2.0" description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/onnx/Cargo.toml b/onnx/Cargo.toml index 8623cf64bd..7e36dbec44 100644 --- a/onnx/Cargo.toml +++ b/onnx/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-onnx" -version = "0.23.0-pre" +version = "0.23.1-pre" authors = ["Mathieu Poumeyrol "] license = "MIT OR Apache-2.0" description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/onnx/src/ops/nn/group_query_attention.rs b/onnx/src/ops/nn/group_query_attention.rs new file mode 100644 index 0000000000..f7ff46f0a1 --- /dev/null +++ b/onnx/src/ops/nn/group_query_attention.rs @@ -0,0 +1,210 @@ +use crate::model::ParsingContext; +use crate::pb::NodeProto; +use 
tract_core::ops::change_axes::AxisOp; +use tract_hir::internal::*; +use tract_transformers::ops::sdpa::Sdpa; + +// com.microsoft GroupQueryAttention (prefill only). +// inputs: query(0), key(1), value(2), past_key(3), past_value(4), seqlens_k(5), total_seq(6) +// outputs: output(0), present_key(1), present_value(2) +// Scoped to the prefill case (no past KV cache): query/key/value are [B, S, heads*head_size], +// attention is causal with q_seq == kv_seq, and present_key/value are the reshaped K/V. +// Sliding-window attention (local_window_size) is applied as a banded causal mask. +// Decode-step KV cache, internal rotary (do_rotary) and softcap are still rejected. +pub fn group_query_attention( + _ctx: &ParsingContext, + node: &NodeProto, +) -> TractResult<(Box, Vec)> { + let num_heads: usize = node.get_attr("num_heads")?; + let kv_num_heads: usize = node.get_attr("kv_num_heads")?; + let scale = node.get_attr_opt::("scale")?; + ensure!( + node.get_attr_opt::("do_rotary")?.unwrap_or(0) == 0, + "GroupQueryAttention: internal rotary (do_rotary) is unsupported; apply RotaryEmbedding separately" + ); + // Sliding-window attention: a query attends only to the last `window` keys (in + // addition to causal). <= 0 / absent means no window (full causal). Applied as a band. + let window = node.get_attr_opt::("local_window_size")?.unwrap_or(0).max(0) as usize; + ensure!( + node.get_attr_opt::("softcap")?.unwrap_or(0.0) == 0.0, + "GroupQueryAttention: softcap is unsupported" + ); + let have_past = (node.input.len() > 3 && !node.input[3].is_empty()) + || (node.input.len() > 4 && !node.input[4].is_empty()); + ensure!( + !have_past, + "GroupQueryAttention: past KV cache (decode step) is unsupported; only prefill is handled" + ); + Ok((expand(GroupQueryAttention { num_heads, kv_num_heads, scale, window }), vec![])) +} + +#[derive(Debug, Clone)] +struct GroupQueryAttention { + num_heads: usize, + kv_num_heads: usize, + scale: Option, + /// Sliding-window size; 0 = full causal (no window). 
+ window: usize, +} + +/// Additive attention mask for causal + optional sliding window: entry `(i, j)` is `0` +/// when query `i` may attend to key `j` (`j <= i`, and `i - j < window` when `window > 0`), +/// else `-inf`. `window == 0` is plain causal. +fn windowed_causal_mask(qs: usize, ks: usize, window: usize) -> tract_ndarray::Array2 { + tract_ndarray::Array2::::from_shape_fn((qs, ks), |(i, j)| { + if j <= i && (window == 0 || i - j < window) { 0.0f32 } else { f32::NEG_INFINITY } + }) +} + +// [B, S, heads*head_size] -> [B, heads, S, head_size] +fn to_4d( + model: &mut TypedModel, + prefix: &str, + x: OutletId, + total: TDim, + heads: usize, +) -> TractResult { + let head_dim = total.clone() / heads; + let reshaped = model.wire_node( + format!("{prefix}.reshape"), + AxisOp::Reshape(2, tvec![total], tvec![heads.to_dim(), head_dim]), + &[x], + )?[0]; + Ok(model.wire_node(format!("{prefix}.transpose"), AxisOp::Move(2, 1), &[reshaped])?[0]) +} + +impl Expansion for GroupQueryAttention { + fn name(&self) -> StaticName { + "GroupQueryAttention".into() + } + + fn nboutputs(&self) -> TractResult { + Ok(3) + } + + fn rules<'r, 'p: 'r, 's: 'r>( + &'s self, + s: &mut Solver<'r>, + inputs: &'p [TensorProxy], + outputs: &'p [TensorProxy], + ) -> InferenceResult { + check_output_arity(outputs, 3)?; + s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?; + s.equals(&inputs[0].shape, &outputs[0].shape)?; + s.equals(&inputs[0].datum_type, &outputs[1].datum_type)?; + s.equals(&inputs[0].datum_type, &outputs[2].datum_type)?; + // present_key / present_value = key/value reshaped to [B, kv_num_heads, S, head_dim]. 
+ let kvh = self.kv_num_heads; + s.given(&inputs[1].shape, move |s, ks| { + s.equals( + &outputs[1].shape, + tvec![ks[0].clone(), kvh.to_dim(), ks[1].clone(), ks[2].clone() / kvh], + ) + })?; + s.given(&inputs[2].shape, move |s, vs| { + s.equals( + &outputs[2].shape, + tvec![vs[0].clone(), kvh.to_dim(), vs[1].clone(), vs[2].clone() / kvh], + ) + })?; + Ok(()) + } + + fn wire( + &self, + prefix: &str, + model: &mut TypedModel, + inputs: &[OutletId], + ) -> TractResult> { + let q_fact = model.outlet_fact(inputs[0])?.clone(); + let dt = q_fact.datum_type; + ensure!(q_fact.rank() == 3, "GroupQueryAttention: expected 3D query [B, S, hidden]"); + let q_hidden = q_fact.shape[2].clone(); + let k_hidden = model.outlet_fact(inputs[1])?.shape[2].clone(); + let v_hidden = model.outlet_fact(inputs[2])?.shape[2].clone(); + + let q4 = to_4d(model, &format!("{prefix}.q"), inputs[0], q_hidden.clone(), self.num_heads)?; + let k4 = to_4d(model, &format!("{prefix}.k"), inputs[1], k_hidden, self.kv_num_heads)?; + let v4 = to_4d(model, &format!("{prefix}.v"), inputs[2], v_hidden, self.kv_num_heads)?; + + // Causal (+ optional sliding-window) mask: materialise an explicit additive band + // for concrete shapes — query i attends to keys j with `j <= i` (causal) and, if a + // window is set, `i - j < window` (the last `window` keys). Fall back to Sdpa's own + // is_causal for symbolic shapes (a static window can't be built there). Sdpa handles + // GQA head grouping (kv heads < q heads). 
+ let q_seq = model.outlet_fact(q4)?.shape[2].to_usize().ok(); + let kv_seq = model.outlet_fact(k4)?.shape[2].to_usize().ok(); + let window = self.window; + let (mask, is_causal) = if let (Some(qs), Some(ks)) = (q_seq, kv_seq) { + let arr = windowed_causal_mask(qs, ks, window); + let mask_tensor: Tensor = arr.into(); + let mut m = model.add_const(format!("{prefix}.causal_mask"), mask_tensor)?; + for i in 0..2 { + m = model.wire_node( + format!("{prefix}.mask_unsqueeze_{i}"), + AxisOp::Add(0), + &[m], + )?[0]; + } + (Some(m), false) + } else { + ensure!( + window == 0, + "GroupQueryAttention: sliding window (local_window_size) requires static \ + sequence lengths to materialise the banded mask" + ); + (None, true) + }; + let mut sdpa_inputs = tvec![q4, k4, v4]; + if let Some(m) = mask { + sdpa_inputs.push(m); + } + let sdpa = Sdpa { + scale: self.scale.map(tensor0), + datum_type: dt, + acc_datum_type: DatumType::F32, + is_causal, + }; + let y4 = model.wire_node(format!("{prefix}.sdpa"), sdpa, &sdpa_inputs)?[0]; + + // [B, num_heads, S, head_dim] -> [B, S, num_heads, head_dim] -> [B, S, hidden] + let y_t = model.wire_node(format!("{prefix}.y_transpose"), AxisOp::Move(1, 2), &[y4])?[0]; + let yf = model.outlet_fact(y4)?.clone(); + let (heads_dim, head_dim) = (yf.shape[1].clone(), yf.shape[3].clone()); + let y = model.wire_node( + format!("{prefix}.y_reshape"), + AxisOp::Reshape( + 2, + tvec![heads_dim.clone(), head_dim.clone()], + tvec![heads_dim * head_dim], + ), + &[y_t], + )?[0]; + + Ok(tvec!(y, k4, v4)) + } +} + +#[cfg(test)] +mod tests { + use super::windowed_causal_mask; + + #[test] + fn band_mask_causal_and_window() { + // window 3 on a 5x5: query i attends to key j iff causal (j<=i) AND i-j<3 + let m = windowed_causal_mask(5, 5, 3); + for i in 0..5 { + for j in 0..5 { + let want_open = j <= i && i - j < 3; + assert_eq!(m[(i, j)] == 0.0, want_open, "window=3 at (i={i}, j={j})"); + } + } + // window 0 == plain causal + let c = windowed_causal_mask(4, 4, 0); + 
for i in 0..4 { + for j in 0..4 { + assert_eq!(c[(i, j)] == 0.0, j <= i, "causal at (i={i}, j={j})"); + } + } + } +} diff --git a/onnx/src/ops/nn/layer_norm.rs b/onnx/src/ops/nn/layer_norm.rs index c84018e86c..c7daa62b56 100644 --- a/onnx/src/ops/nn/layer_norm.rs +++ b/onnx/src/ops/nn/layer_norm.rs @@ -161,20 +161,20 @@ impl Expansion for LayerNorm { let normalized = model.wire_node(format!("{prefix}.normalized"), mul(), &[d[0], inv_std_dev[0]])?; // NormalizedScaled = Mul(Normalized, Scale) Y = Add(NormalizedScaled, B) - let cast_normalized = model.wire_node( - format!("{prefix}.cast_normalized"), - cast(fact.datum_type), - &normalized, - )?; + // Keep `normalized` in self.datum_type (typically F32) through the + // scale/bias application, then cast back to fact.datum_type at the + // very end. Casting back too early causes a type mismatch with + // `cast_scale` / `cast_bias` (which are in self.datum_type), which + // breaks F16-input LayerNorm (output is F32 but fact says F16). let normalized_scaled = wire_with_rank_broadcast( format!("{prefix}.normalized_scaled"), model, mul(), - &[cast_normalized[0], cast_scale[0]], + &[normalized[0], cast_scale[0]], )?; - let y = if let Some(bias) = cast_bias { + let y_internal = if let Some(bias) = cast_bias { wire_with_rank_broadcast( - format!("{prefix}.y"), + format!("{prefix}.y_internal"), model, add(), &[normalized_scaled[0], bias[0]], @@ -182,6 +182,7 @@ impl Expansion for LayerNorm { } else { normalized_scaled }; + let y = model.wire_node(format!("{prefix}.cast_y"), cast(fact.datum_type), &y_internal)?; let mut outputs = tvec!(y[0]); if self.mean_output.is_some() { outputs.push(reduced_mean_x[0]); diff --git a/onnx/src/ops/nn/mod.rs b/onnx/src/ops/nn/mod.rs index ad6056574b..c0434a6365 100644 --- a/onnx/src/ops/nn/mod.rs +++ b/onnx/src/ops/nn/mod.rs @@ -14,6 +14,7 @@ mod dropout; mod gelu; mod gelu_contrib; mod group_norm; +mod group_query_attention; mod instance_norm; mod layer_norm; mod lp_norm; @@ -93,6 +94,7 @@ pub 
fn register_all_ops(reg: &mut OnnxOpRegister) { reg.insert("BiasGelu", gelu_contrib::bias_gelu); reg.insert("FastGelu", gelu_contrib::fast_gelu); reg.insert("QuickGelu", gelu_contrib::quick_gelu); + reg.insert("GroupQueryAttention", group_query_attention::group_query_attention); reg.insert("HardSwish", |_, _| Ok((ops::nn::hard_swish().into_hir(), vec![]))); reg.insert("Mish", |_, _| Ok((expand(mish::Mish), vec![]))); reg.insert("MultiHeadAttention", multi_head_attention::multi_head_attention); diff --git a/onnx/src/ops/nn/rotary_embedding.rs b/onnx/src/ops/nn/rotary_embedding.rs index e01a13a661..07351103aa 100644 --- a/onnx/src/ops/nn/rotary_embedding.rs +++ b/onnx/src/ops/nn/rotary_embedding.rs @@ -18,7 +18,22 @@ pub fn rotary_embedding( let num_heads = node.get_attr_opt::("num_heads")?.unwrap_or(0) as usize; let rotary_embedding_dim = node.get_attr_opt::("rotary_embedding_dim")?.unwrap_or(0) as usize; - Ok((expand(RotaryEmbedding { interleaved, num_heads, rotary_embedding_dim }), vec![])) + // The com.microsoft contrib op is identical math but orders its inputs + // (input, position_ids, cos, sin) and adds `scale`/`is_packed_batching` attributes. + let microsoft = node.domain == "com.microsoft"; + if microsoft { + let scale = node.get_attr_opt::("scale")?.unwrap_or(1.0); + ensure!( + scale == 1.0, + "com.microsoft RotaryEmbedding: scale={scale} (!= 1.0) is unsupported" + ); + let packed = node.get_attr_opt::("is_packed_batching")?.unwrap_or(0); + ensure!(packed == 0, "com.microsoft RotaryEmbedding: is_packed_batching=1 is unsupported"); + } + Ok(( + expand(RotaryEmbedding { interleaved, num_heads, rotary_embedding_dim, microsoft }), + vec![], + )) } #[derive(Debug, Clone, new)] @@ -26,6 +41,8 @@ struct RotaryEmbedding { interleaved: bool, num_heads: usize, rotary_embedding_dim: usize, + /// com.microsoft contrib variant: inputs are (input, position_ids, cos, sin). 
+ microsoft: bool, } impl Expansion for RotaryEmbedding { @@ -60,7 +77,14 @@ impl Expansion for RotaryEmbedding { let in_fact = model.outlet_fact(inputs[0])?.clone(); let in_rank = in_fact.rank(); ensure!(in_rank == 3 || in_rank == 4, "RotaryEmbedding expects rank 3 or 4, got {in_rank}"); - let has_position_ids = inputs.len() == 4; + // Input layout differs by domain: + // ai.onnx: input, cos, sin, position_ids? (position_ids optional, index 3) + // com.microsoft: input, position_ids, cos, sin (position_ids at index 1) + let (cos_idx, sin_idx, pos_idx) = if self.microsoft { + (2usize, 3usize, Some(1usize)) + } else { + (1usize, 2usize, (inputs.len() == 4).then_some(3usize)) + }; let two = 2usize.to_dim(); // 1. Normalize input to [batch, seq, heads, head_size]. @@ -107,11 +131,11 @@ impl Expansion for RotaryEmbedding { // 3. Prepare cos/sin as [B, S, 1, half] (gathering by position_ids if present). let mut prep = |tag: &str, cache: usize| -> TractResult { - let gathered = if has_position_ids { + let gathered = if let Some(pi) = pos_idx { model.wire_node( format!("{prefix}.{tag}_gather"), Gather::new(0), - &[inputs[cache], inputs[3]], + &[inputs[cache], inputs[pi]], )?[0] } else { inputs[cache] @@ -119,8 +143,8 @@ impl Expansion for RotaryEmbedding { Ok(model.wire_node(format!("{prefix}.{tag}_unsqueeze"), AxisOp::Add(2), &[gathered])? [0]) }; - let cos = prep("cos", 1)?; - let sin = prep("sin", 2)?; + let cos = prep("cos", cos_idx)?; + let sin = prep("sin", sin_idx)?; // 4. Extract the two rotated components. let (x1, x2) = if self.interleaved { diff --git a/post-release.sh b/post-release.sh index e6c49decad..e606930ec4 100755 --- a/post-release.sh +++ b/post-release.sh @@ -12,7 +12,13 @@ fi if [ -z "$VERSION" ] then - echo "Usage: $0 " + echo "Usage: $0 " + exit 1 +fi + +if ! echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+([.-][A-Za-z0-9.-]+)?$' +then + echo "Refusing version '$VERSION': must look like 0.23.0 or 0.23.0-pre (no leading 'v')." 
>&2 exit 1 fi @@ -24,5 +30,9 @@ do tomato set workspace.dependencies.$crate.version $VERSION Cargo.toml done +# tomato edits the manifests only; sync Cargo.lock to the bumped workspace +# versions so the committed lock matches (otherwise the next build rewrites it). +cargo update --workspace --offline + git commit . -m "post-release $VERSION" git push diff --git a/pulse-opl/Cargo.toml b/pulse-opl/Cargo.toml index 4ba12dfa53..fba02c9d2b 100644 --- a/pulse-opl/Cargo.toml +++ b/pulse-opl/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-pulse-opl" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol "] description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/pulse/Cargo.toml b/pulse/Cargo.toml index 698f4061f6..504cb1973e 100644 --- a/pulse/Cargo.toml +++ b/pulse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-pulse" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol "] description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/pulse/src/blockify.rs b/pulse/src/blockify.rs index 10c4cf1f17..65f4c41e48 100644 --- a/pulse/src/blockify.rs +++ b/pulse/src/blockify.rs @@ -30,7 +30,7 @@ //! Determine the score axis the terminator contracts and translate //! it to mask frame. //! 2. **Substitute** the streaming symbol globally `T → k·S` via core's -//! `substitute_symbols`. +//! `set_symbols`. //! 3. **Rewrite** one `TypedModelPatch` per section //! (`build_section_patch`). Sections are independent so patches //! apply in sequence. 
A recognised section gets fully rewritten or @@ -188,7 +188,7 @@ impl ModelTransform for BlockifyTransform { let chunk_sym = model.symbols.new_with_prefix("S"); let subs: HashMap = HashMap::from([(stream_sym.clone(), chunk_sym.to_dim() * k)]); - let new_model = model.substitute_symbols(&subs)?; + let new_model = model.set_symbols(&subs)?; *model = new_model; rewrite_sections(model, &chunk_sym, k)?; model.properties.insert( diff --git a/pulse/src/lib.rs b/pulse/src/lib.rs index 73ef44d750..f8f943edac 100644 --- a/pulse/src/lib.rs +++ b/pulse/src/lib.rs @@ -293,4 +293,66 @@ mod tests { // downstream meet point. Now it should pulsify without error. let _pulse = PulsedModel::new(&model, t, &2.to_dim()).expect("pulsification"); } + + /// `MultiBroadcastTo` pulsifier baseline: a target shape that grows + /// linearly with the streaming symbol (`1 + S/2` -- the canonical + /// `shape_of(stride-2 conv)` pattern) gets the per-pulse increment + /// `P/2` after the boundary-subtract trick. Locks in the existing + /// linear contract so the non-linear fallback below cannot regress + /// it. + #[test] + fn test_multi_broadcast_to_pulsifier_linear_axis() { + use tract_pulse_opl::tract_core::ops::array::MultiBroadcastTo; + + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + let linear: TDim = 1.to_dim() + s.to_dim() / 2; + let target_shape: ShapeFact = tvec![1.to_dim(), linear, 4.to_dim()].into(); + + let a = model.add_source("a", f32::fact(dims![1, s.to_dim(), 4].as_ref())).unwrap(); + let out = model.wire_node("bc", MultiBroadcastTo::new(target_shape), &[a]).unwrap(); + model.select_output_outlets(&out).unwrap(); + + let pulse = PulsedModel::new(&model, s, &4.to_dim()).expect("pulsification"); + let out_fact = pulse.output_fact(0).unwrap(); + // `1 + S/2` at S=P=4 is 3, at S=0 it is 1. The trick yields + // `3 - 1 = 2` per pulse; the linearity probe at S=8 gives delta + // 4, matching `2 * 2`, so we stay on the linear path. 
+ assert_eq!( + out_fact.shape[1], + 2.to_dim(), + "linear streaming axis must keep the boundary-subtract delta; got fact: {out_fact:?}", + ); + } + + /// Non-linear target shape (`min(2, S + 2)`, which equals 2 for every + /// `S >= 0`): the boundary-subtract collapses `full - base` to 0 even + /// though the full per-pulse shape is 2. Pre-fix that produced a + /// 0-volume PulsedFact that poisoned every downstream consumer (most + /// visibly: a Scan body's State input reading the GRU `h_0` tile, + /// surfacing as `Clashing resolution for expression. 2=2 != 0` on the + /// runtime warmup turn). The fallback keeps the full value when the + /// `substitute(S→0) == substitute(S→P) == substitute(S→2P)` probe + /// confirms the axis is not actually streaming. + #[test] + fn test_multi_broadcast_to_pulsifier_non_linear_axis() { + use tract_pulse_opl::tract_core::ops::array::MultiBroadcastTo; + + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + let non_linear: TDim = (s.to_dim() + 2.to_dim()).mini(2.to_dim()); + let target_shape: ShapeFact = tvec![1.to_dim(), non_linear, 1.to_dim()].into(); + + let a = model.add_source("a", f32::fact(dims![1, s.to_dim(), 1].as_ref())).unwrap(); + let out = model.wire_node("bc", MultiBroadcastTo::new(target_shape), &[a]).unwrap(); + model.select_output_outlets(&out).unwrap(); + + let pulse = PulsedModel::new(&model, s, &4.to_dim()).expect("pulsification"); + let out_fact = pulse.output_fact(0).unwrap(); + assert_eq!( + out_fact.shape[1], + 2.to_dim(), + "non-linear streaming axis must keep the full value, not the collapsed delta; got fact: {out_fact:?}", + ); + } } diff --git a/pulse/src/model.rs b/pulse/src/model.rs index d84ea20dd0..b825946c01 100644 --- a/pulse/src/model.rs +++ b/pulse/src/model.rs @@ -102,7 +102,7 @@ fn pulse_driven_blockify( model.symbols.add_assertion(format!("{chunk_sym} >= 0"))?; let subs: HashMap = HashMap::from([(symbol.clone(), chunk_sym.to_dim() * pulse_value)]); - *model = 
model.substitute_symbols(&subs)?; + *model = model.set_symbols(&subs)?; crate::blockify::rewrite_sections(model, &chunk_sym, pulse_value)?; model.properties.insert( crate::blockify::BLOCKIFY_ORIGINAL_SYMBOL.to_string(), diff --git a/pulse/src/ops/array/broadcast.rs b/pulse/src/ops/array/broadcast.rs index e3cb826a43..78c80c7340 100644 --- a/pulse/src/ops/array/broadcast.rs +++ b/pulse/src/ops/array/broadcast.rs @@ -23,14 +23,7 @@ fn pulsify( .enumerate() .map(|(i, dim)| { if i == axis { - // Remove the constant boundary term so that per-pulse output size - // matches the actual pulsed output of any upstream strided conv. - // E.g. shape_of(stride-2 conv) = 1 + S/2: - // substitute(S→P) = 1 + P/2 (wrong) - // substitute(S→P) - substitute(S→0) = P/2 (correct) - let full = dim.substitute(symbol, pulse)?; - let base = dim.substitute(symbol, &TDim::Val(0))?; - Ok(full - base) + pulsified_stream_axis_dim(dim, symbol, pulse) } else { dim.substitute(symbol, pulse) } @@ -45,6 +38,54 @@ fn pulsify( } } +/// Compute the per-pulse size for a `MultiBroadcastTo` target axis whose +/// shape mentions the streaming symbol. +/// +/// The canonical pattern emitted by ONNX `Expand` / `BroadcastTo` against +/// a `shape_of(streaming)` chain is a *linear* dim of the form +/// `pulse_growth(S) + boundary` with `pulse_growth(0) = 0`, e.g. a +/// stride-2 conv's output length `1 + S/2`. For that pattern the per-pulse +/// increment is `dim(P) - dim(0)` and we use it as the pulse-axis size. +/// +/// When the expression is constant over `[0, P]` or non-linear in `S`, +/// the same subtraction can collapse `full - base` to `0` while `full` +/// itself is a positive valid shape. That happens for chunked-batch +/// expressions like `1 + -1*min(2, -1+(8·S)/5) + (8·S)/5` (which equals +/// `max(2, (8·S)/5 - 1)`), where every `S ∈ {0, P, 2P}` resolves to the +/// same lower bound. 
The consumer of such an axis (a Scan body state +/// init, an elementwise meet point) reads the *full* per-pulse shape, +/// not an empty delta — emitting a `0`-volume PulsedFact poisons every +/// downstream fact. +/// +/// Heuristic: probe at `S=0`, `S=P`, and `S=2P`. Use the linear +/// subtraction iff the delta is strictly positive and `delta(2P) == +/// 2·delta(P)` (provably linear over the probe interval). Otherwise +/// fall back to `dim(P)`. +fn pulsified_stream_axis_dim(dim: &TDim, symbol: &Symbol, pulse: &TDim) -> TractResult { + let full = dim.substitute(symbol, pulse)?; + let base = dim.substitute(symbol, &TDim::Val(0))?; + let delta = full.clone() - base.clone(); + // Constant on `[0, P]` — this axis is not actually streaming on this + // pulse window. Use the full value so downstream facts stay + // non-degenerate. + if delta == 0.to_dim() { + return Ok(full); + } + // Confirm linearity by sampling at `2P`. Only worthwhile when `P` is + // a concrete positive integer; for symbolic `pulse` the trick falls + // back to the existing behavior (treat as linear). + if let Some(pulse_v) = pulse.as_i64() + && pulse_v > 0 + { + let double = dim.substitute(symbol, &TDim::Val(pulse_v * 2))?; + let delta_double = double - base; + if delta_double != delta.clone() * 2 { + return Ok(full); + } + } + Ok(delta) +} + /// Concat with pulse along concat axis #[derive(Debug, Clone, Hash, PartialEq, Eq)] struct PulsedMultibroadcastTo { diff --git a/pulse/src/ops/fft.rs b/pulse/src/ops/fft.rs index c7d809f5d7..976b8c76d4 100644 --- a/pulse/src/ops/fft.rs +++ b/pulse/src/ops/fft.rs @@ -3,9 +3,15 @@ use crate::internal::*; use tract_core::ops::fft::Stft; use tract_pulse_opl::ops::Delay; -register_all!(Stft: pulsify); +register_all!(Stft: stft_pulsify); -fn pulsify( +// `Fft` needs no dedicated pulsifier. 
Its `axes_mapping` declares the +// FFT axis and the trailing complex axis as input-only / output-only +// (they do not map 1-to-1), so the generic pulse fallback in `model.rs` +// automatically refuses to track a streaming axis through them and +// handles streaming on any genuine batch axis with `PulseWrappingOp`. + +fn stft_pulsify( op: &Stft, _source: &TypedModel, node: &TypedNode, diff --git a/pulse/src/ops/scan.rs b/pulse/src/ops/scan.rs index 1db251f53f..bf2a185300 100644 --- a/pulse/src/ops/scan.rs +++ b/pulse/src/ops/scan.rs @@ -40,6 +40,13 @@ fn pulsify( let first_scan_axis = target.outlet_fact(pulse_inputs[first_scan_slot])?.stream.as_ref().unwrap().axis; let scan_axis = axes_mapping.axis((InOut::In(first_scan_slot), first_scan_axis))?; + // Bake the same `symbol -> pulse` substitution the outer pulsifier just + // applied to the wire facts into the Scan body and output_mapping. Without + // it, post-pulse declutter folds outer and body shape expressions in + // independent scopes and lands on different literals, producing a silent + // drift between outer-input facts and body source facts on every dim + // that mentioned the stream symbol. 
+ let subs: HashMap = HashMap::from([(symbol.clone(), pulse.clone())]); if first_scan_axis == op.input_mapping[first_scan_slot].as_scan().unwrap().axis { let mut op = op.clone(); op.skip = target.outlet_fact(pulse_inputs[first_scan_slot])?.stream.as_ref().unwrap().delay; @@ -48,10 +55,19 @@ fn pulsify( om.full_dim_hint = None; } } + op.body = op.body.set_symbols(&subs)?; + for om in op.output_mapping.iter_mut() { + *om = om.set_symbols(&subs)?; + } Ok(Some(target.wire_node(&*node.name, op, &pulse_inputs)?)) } else if scan_axis.outputs.iter().all(|x| x.len() == 1) { let body = PulsedModel::new(&op.body, symbol.clone(), pulse)?.into_typed()?; - let mut new_op = Scan::new(body, op.input_mapping.clone(), op.output_mapping.clone(), 0)?; + let output_mapping = op + .output_mapping + .iter() + .map(|om| om.set_symbols(&subs)) + .collect::>>()?; + let mut new_op = Scan::new(body, op.input_mapping.clone(), output_mapping, 0)?; new_op.reset_every_turn = true; target.wire_node(&node.name, new_op, &pulse_inputs).map(Some) } else { diff --git a/release.sh b/release.sh index 5692a80405..c23ed77b7d 100755 --- a/release.sh +++ b/release.sh @@ -13,11 +13,17 @@ VERSION=$2 if [ -z "$VERSION" ] then - echo "Usage: $0 " + echo "Usage: $0 " echo crates order is: $ALL_CRATES_PATH exit 1 fi +if ! echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+([.-][A-Za-z0-9.-]+)?$' +then + echo "Refusing version '$VERSION': must look like 0.23.0 or 0.23.0-pre (no leading 'v')." 
>&2 + exit 1 +fi + set -ex if [ "$CRATE_PATH" = "all" ] diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000000..73cb934de4 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "stable" +components = ["rustfmt", "clippy"] diff --git a/tensorflow/Cargo.toml b/tensorflow/Cargo.toml index b809644cd7..7783b0ce44 100644 --- a/tensorflow/Cargo.toml +++ b/tensorflow/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-tensorflow" -version = "0.23.0-pre" +version = "0.23.1-pre" authors = ["Mathieu Poumeyrol "] license = "MIT OR Apache-2.0" description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" @@ -22,7 +22,6 @@ log.workspace = true memmap2.workspace = true prost.workspace = true prost-types.workspace = true -tensorflow = { workspace = true, optional = true } tract-hir.workspace = true tract-pulse.workspace = true @@ -30,9 +29,6 @@ tract-pulse.workspace = true # protobuf-src = "1.0.5+3.19.3" # prost-build = "0.14" -[features] -conform = [ "tensorflow" ] - [dev-dependencies] criterion.workspace = true env_logger.workspace = true diff --git a/tensorflow/src/conform/mod.rs b/tensorflow/src/conform/mod.rs deleted file mode 100644 index 8ec61a57d4..0000000000 --- a/tensorflow/src/conform/mod.rs +++ /dev/null @@ -1,48 +0,0 @@ -#![allow(unused)] -#![allow(deprecated)] -#![allow(non_snake_case)] - -pub mod tf; - -use crate::tfpb; -use crate::tfpb::tensorflow::tensor_shape_proto::Dim; -use crate::tfpb::tensorflow::{DataType, TensorProto, TensorShapeProto}; -use std::convert::TryInto; -use tract_hir::internal::*; - -pub fn placeholder>>( - name: &str, - t: DataType, - shape: Shape, -) -> tfpb::tensorflow::NodeDef { - let mut node = tfpb::node().name(name).op("Placeholder").attr("dtype", t); - if let Some(shape) = shape.into() { - node = node.attr("shape", shape) - } - node -} - -pub fn tensor_shape(dims: &[usize]) -> TensorShapeProto { - TensorShapeProto { - dim: dims.iter().map(|&d| Dim { size: d 
as i64, name: String::new() }).collect(), - unknown_rank: false, - } -} - -pub fn const_f32(name: &str, t: &Tensor) -> tfpb::tensorflow::NodeDef { - let tf: TensorProto = t.cast_to::().unwrap().as_ref().try_into().unwrap(); - tfpb::node().name(name).op("Const").attr("dtype", DataType::DtFloat).attr("value", tf) -} - -pub fn placeholder_f32(name: &str) -> tfpb::tensorflow::NodeDef { - placeholder(name, DataType::DtFloat, None) -} - -pub fn const_i32(name: &str, t: &Tensor) -> tfpb::tensorflow::NodeDef { - let tf: TensorProto = t.cast_to::().unwrap().as_ref().try_into().unwrap(); - tfpb::node().name(name).op("Const").attr("dtype", DataType::DtInt32).attr("value", tf) -} - -pub fn placeholder_i32(name: &str) -> tfpb::tensorflow::NodeDef { - placeholder(name, DataType::DtInt32, None) -} diff --git a/tensorflow/src/conform/tf.rs b/tensorflow/src/conform/tf.rs deleted file mode 100644 index bad6781e81..0000000000 --- a/tensorflow/src/conform/tf.rs +++ /dev/null @@ -1,272 +0,0 @@ -#![allow(dead_code)] - -use std::{fs, path}; - -use tensorflow as tf; -use tensorflow::DataType; -use tensorflow::FetchToken; -use tensorflow::Graph; -use tensorflow::Session; -use tensorflow::SessionRunArgs; - -use tract_hir::internal::*; -use tract_ndarray::prelude::*; - -use std::collections::HashMap; -use std::collections::HashSet; - -pub struct Tensorflow { - graph: Graph, -} - -pub fn version() -> String { - tf::version().unwrap() -} - -pub fn for_path>(p: P) -> TractResult { - use std::io::Read; - let mut model = vec![]; - fs::File::open(p)?.read_to_end(&mut model)?; - for_slice(&*model) -} - -pub fn for_slice(buf: &[u8]) -> TractResult { - let mut graph = Graph::new(); - graph.import_graph_def(buf, &::tensorflow::ImportGraphDefOptions::new())?; - Ok(Tensorflow { graph }) -} - -enum TensorHolder { - Bool(tf::Tensor), - F16(tf::Tensor<::tensorflow::BFloat16>), - F32(tf::Tensor), - F64(tf::Tensor), - U8(tf::Tensor), - U16(tf::Tensor), - I8(tf::Tensor), - I16(tf::Tensor), - I32(tf::Tensor), 
- I64(tf::Tensor), - String(tf::Tensor), -} - -impl TensorHolder { - fn to_tensor(m: ArrayD) -> tf::Tensor { - let dims: Vec = m.shape().iter().map(|d| *d as _).collect(); - let mut tensor = tf::Tensor::::new(&*dims); - tensor.copy_from_slice(m.as_slice().unwrap()); - tensor - } -} - -impl From for TensorHolder { - fn from(m: Tensor) -> TensorHolder { - match m.datum_type() { - DatumType::Bool => TensorHolder::Bool(Self::to_tensor(m.into_plain_array().unwrap())), - DatumType::F16 => unimplemented!(), - DatumType::F32 => TensorHolder::F32(Self::to_tensor(m.into_plain_array().unwrap())), - DatumType::F64 => TensorHolder::F64(Self::to_tensor(m.into_plain_array().unwrap())), - DatumType::I8 => TensorHolder::I8(Self::to_tensor(m.into_plain_array().unwrap())), - DatumType::I16 => TensorHolder::I16(Self::to_tensor(m.into_plain_array().unwrap())), - DatumType::I32 => TensorHolder::I32(Self::to_tensor(m.into_plain_array().unwrap())), - DatumType::I64 => TensorHolder::I64(Self::to_tensor(m.into_plain_array().unwrap())), - DatumType::U8 => TensorHolder::U8(Self::to_tensor(m.into_plain_array().unwrap())), - DatumType::U16 => TensorHolder::U16(Self::to_tensor(m.into_plain_array().unwrap())), - DatumType::U32 => TensorHolder::U16(Self::to_tensor(m.into_plain_array().unwrap())), - DatumType::U64 => TensorHolder::U16(Self::to_tensor(m.into_plain_array().unwrap())), - DatumType::QU8(_) => TensorHolder::U8(Self::to_tensor(m.into_plain_array().unwrap())), - DatumType::QI8(_) => TensorHolder::I8(Self::to_tensor(m.into_plain_array().unwrap())), - DatumType::QI32(_) => TensorHolder::I32(Self::to_tensor(m.into_plain_array().unwrap())), - #[cfg(feature = "complex")] - DatumType::ComplexI16 => unimplemented!(), - #[cfg(feature = "complex")] - DatumType::ComplexI32 => unimplemented!(), - #[cfg(feature = "complex")] - DatumType::ComplexI64 => unimplemented!(), - #[cfg(feature = "complex")] - DatumType::ComplexF16 => unimplemented!(), - #[cfg(feature = "complex")] - DatumType::ComplexF32 => 
unimplemented!(), - #[cfg(feature = "complex")] - DatumType::ComplexF64 => unimplemented!(), - DatumType::TDim => { - let dims = m.to_plain_array_view::().unwrap(); - if let Ok(dims) = dims.iter().map(|d| d.to_i32()).collect::>>() { - TensorHolder::I32(Self::to_tensor(arr1(&dims).into_dyn())) - } else { - panic!("Streaming used in tensorflow settings") - } - } - DatumType::String => { - TensorHolder::String(Self::to_tensor(m.into_plain_array().unwrap())) - } - DatumType::Blob => TensorHolder::String(Self::to_tensor(m.into_plain_array().unwrap())), - _ => panic!("No support for {:?} DT in tensorflow", m.datum_type()), - } - } -} - -fn tensor_to_array(tensor: &tf::Tensor) -> TractResult> { - let shape: Vec = tensor.dims().iter().map(|d| *d as _).collect(); - Ok(Array::from(tensor.into_iter().cloned().collect::>()).into_shape_with_order(shape)?) -} - -impl Tensorflow { - /// Executes the graph in one batch. - pub fn run( - &mut self, - inputs: Vec<(&str, Tensor)>, - output_name: &str, - ) -> TractResult> { - let tensors: Vec<(&str, TensorHolder)> = - inputs.into_iter().map(|(name, mat)| (name, mat.into())).collect(); - - let mut step = SessionRunArgs::new(); - for t in &tensors { - let op = self.graph.operation_by_name_required(t.0)?; - match t.1 { - TensorHolder::Bool(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::U8(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::U16(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::I8(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::I16(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::I32(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::I64(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::F16(_) => unimplemented!(), - TensorHolder::F32(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::F64(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::String(ref it) => step.add_feed(&op, 0, &it), - } - } - - let op = &self.graph.operation_by_name_required(output_name)?; - let tokens = - 
(0..op.num_outputs()).map(|ix| step.request_fetch(&op, ix as i32)).collect::>(); - - let mut session = Session::new(&::tensorflow::SessionOptions::new(), &self.graph)?; - session.run(&mut step)?; - - tokens - .into_iter() - .enumerate() - .map(|(ix, tok)| { - let output_type = - &self.graph.operation_by_name_required(&output_name)?.output_type(ix); - convert_output(&mut step, output_type, tok) - }) - .collect() - } - - /// Executes the graph in one batch, and returns the output for every node but the inputs. - pub fn run_get_many<'a>( - &mut self, - inputs: Vec<(&'a str, Tensor)>, - targets: Vec<&'a str>, - ) -> TractResult>> { - let mut input_pairs: Vec<(&str, TensorHolder)> = Vec::new(); - let mut excluded = HashSet::new(); - - for (name, mat) in inputs { - input_pairs.push((name, mat.into())); - excluded.insert(name.to_string()); - } - - let mut step = SessionRunArgs::new(); - for t in &input_pairs { - let op = self.graph.operation_by_name_required(t.0)?; - match t.1 { - TensorHolder::Bool(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::U8(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::U16(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::I8(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::I16(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::I32(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::I64(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::F16(_) => unimplemented!(), - TensorHolder::F32(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::F64(ref it) => step.add_feed(&op, 0, &it), - TensorHolder::String(ref it) => step.add_feed(&op, 0, &it), - } - } - - let mut tokens = HashMap::new(); - trace!("Targets: {:?}", targets); - for name in targets { - if excluded.contains(name) { - continue; - } - - if let Some(operation) = self.graph.operation_by_name(name)? { - // switch only computes one of its outputs. tf explodes during - // the call to run() if we registers them - if operation.op_type()? 
== "Switch" { - continue; - } - - // this one pretends to have 5 outputs, but has only one - if operation.op_type()? == "FusedBatchNorm" { - continue; - } - - let outputs = (0..operation.num_outputs()) - .map(|ix| step.request_fetch(&operation, ix as i32)) - .collect::>(); - - tokens.insert(name, outputs); - } - } - trace!("Generated all output tokens"); - trace!("{:?}", tokens); - - // Execute the graph using tensorflow. - let mut session = Session::new(&::tensorflow::SessionOptions::new(), &self.graph)?; - session.run(&mut step)?; - trace!("Tensorflow ran succesfully"); - - // Return the output for every node. - let mut outputs = HashMap::new(); - for (name, tokens) in tokens { - let tensors = tokens - .iter() - .enumerate() - .map(|(ix, tok)| { - let output_type = - &self.graph.operation_by_name_required(&name)?.output_type(ix); - convert_output(&mut step, output_type, *tok) - }) - .collect::>>()?; - outputs.insert(name, tensors); - } - - Ok(outputs) - } -} - -/// Converts the output of a Tensorflow node into a Tensor. -fn convert_output( - step: &mut SessionRunArgs, - output_type: &DataType, - output: FetchToken, -) -> TractResult { - macro_rules! 
convert { - ($dt:ident) => { - match step.fetch(output) { - Err(r) => Err(r)?, - Ok(output) => tensor_to_array::<$dt>(&output)?.into(), - } - }; - }; - - let tract_tensor = match output_type { - DataType::Bool => convert!(bool), - DataType::Float => convert!(f32), - DataType::Double => convert!(f64), - DataType::UInt8 => convert!(u8), - DataType::Int8 => convert!(i8), - DataType::Int32 => convert!(i32), - DataType::Int64 => convert!(i64), - t => bail!("Missing conversion for tensorflow to tract (type: {:?})", t), - }; - - Ok(tract_tensor) -} diff --git a/tensorflow/src/lib.rs b/tensorflow/src/lib.rs index 30e17e08ac..752d746b77 100644 --- a/tensorflow/src/lib.rs +++ b/tensorflow/src/lib.rs @@ -43,13 +43,8 @@ extern crate log; extern crate env_logger; extern crate prost; extern crate prost_types; -#[cfg(feature = "conform")] -extern crate tensorflow; pub extern crate tract_hir; -#[cfg(feature = "conform")] -pub mod conform; - pub mod model; pub mod ops; pub mod tensor; diff --git a/test-rt/test-cuda/Cargo.toml b/test-rt/test-cuda/Cargo.toml index 962080a06e..313572f014 100644 --- a/test-rt/test-cuda/Cargo.toml +++ b/test-rt/test-cuda/Cargo.toml @@ -27,7 +27,7 @@ log.workspace = true tract-core.workspace = true tract-onnx-opl.workspace = true infra = { path = "../infra" } -tract-cuda.workspace = true +tract-cuda = { workspace = true, features = ["cuda-13000"] } tract-gpu.workspace = true suite-onnx = { path = "../suite-onnx" } suite-unit = { path = "../suite-unit" } diff --git a/tflite/Cargo.toml b/tflite/Cargo.toml index d323e8a5e5..e2e9765953 100644 --- a/tflite/Cargo.toml +++ b/tflite/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-tflite" -version = "0.23.0-pre" +version = "0.23.1-pre" authors = ["Mathieu Poumeyrol "] license = "MIT OR Apache-2.0" description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" diff --git a/transformers/Cargo.toml b/transformers/Cargo.toml index eb63749c98..4b9ad6f788 100644 --- a/transformers/Cargo.toml +++ 
b/transformers/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tract-transformers" -version = "0.23.0-pre" +version = "0.23.1-pre" license = "MIT OR Apache-2.0" authors = ["Mathieu Poumeyrol ", "Louis Chouraki "] description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference" @@ -17,3 +17,9 @@ maintenance = { status = "actively-developed" } [dependencies] float-ord.workspace = true tract-nnef.workspace = true + +# Used to spread FlashSDPA's independent attention heads across cores (rayon's +# global pool). Not available on wasm32-unknown-unknown (no thread spawn), where +# the op stays single-threaded. +[target.'cfg(not(target_family = "wasm"))'.dependencies] +rayon.workspace = true diff --git a/transformers/src/lib.rs b/transformers/src/lib.rs index fc83923999..74cbd7d965 100644 --- a/transformers/src/lib.rs +++ b/transformers/src/lib.rs @@ -21,6 +21,7 @@ pub fn register(registry: &mut Registry) { ops::scaled_masked_softmax::register(registry); ops::sdpa::register(registry); ops::dyn_kv_cache::register(registry); + ops::window_kv_cache::register(registry); } pub trait WithTractTransformers { diff --git a/transformers/src/ops/diag_gather.rs b/transformers/src/ops/diag_gather.rs index ad8f61b16b..678831dc23 100644 --- a/transformers/src/ops/diag_gather.rs +++ b/transformers/src/ops/diag_gather.rs @@ -142,7 +142,7 @@ impl TypedOp for DiagGather { Ok(Some(tvec![Some(input_roi)])) } - fn substitute_symbols( + fn set_symbols( &self, _source: &TypedModel, node: &TypedNode, diff --git a/transformers/src/ops/flash_sdpa.rs b/transformers/src/ops/flash_sdpa.rs index 08e0856c14..621f7e79ad 100644 --- a/transformers/src/ops/flash_sdpa.rs +++ b/transformers/src/ops/flash_sdpa.rs @@ -174,108 +174,292 @@ impl FlashSdpaOp { v: ArrayView4, mask: Option>, ) -> Array4 { - // Explicit dimensions + self.flash_attention_gqa_impl(q, k, v, mask, false) + } + + fn flash_attention_gqa_impl( + &self, + q: ArrayView4, + k: ArrayView4, + v: ArrayView4, + mask: Option>, + 
pv_strided: bool, + ) -> Array4 { let (batch_size, num_q_heads, q_len, head_dim) = q.dim(); - let (_, num_kv_heads, kv_len, _) = k.dim(); - let scale = self.scale.unwrap_or((head_dim as f32).recip().sqrt()); + let (_, num_kv_heads, _, _) = k.dim(); let group_size = num_q_heads / num_kv_heads; + let mut out = Array4::::zeros((batch_size, num_q_heads, q_len, head_dim)); + + // One independent task per (batch, q-head). Heads share only read-only Q/K/V + // and write disjoint output slices, so this is embarrassingly parallel. + let tasks: Vec<(usize, usize)> = + (0..batch_size).flat_map(|b| (0..num_q_heads).map(move |qh| (b, qh))).collect(); + let compute = |&(b, qh): &(usize, usize)| { + self.attend_one_head(b, qh, qh / group_size, q, k, v, mask.as_ref(), pv_strided) + }; + + let results: Vec> = if Self::should_thread(tasks.len()) { + #[cfg(not(target_family = "wasm"))] + { + use rayon::prelude::*; + tasks.par_iter().map(compute).collect() + } + #[cfg(target_family = "wasm")] + { + tasks.iter().map(compute).collect() + } + } else { + tasks.iter().map(compute).collect() + }; + + for (&(b, qh), head_out) in tasks.iter().zip(results) { + out.slice_mut(s!(b, qh, .., ..)).assign(&head_out); + } + out + } + + /// Whether to spread the per-head tasks across cores (rayon's global pool). + /// Off for a single task, on wasm (no thread spawn), or when `TRACT_FLASH_SDPA_ST` is set. + fn should_thread(num_tasks: usize) -> bool { + num_tasks > 1 + && cfg!(not(target_family = "wasm")) + && std::env::var_os("TRACT_FLASH_SDPA_ST").is_none() + } + + /// Flash-attention for a single (batch, q-head) pair. Returns O of shape + /// `[q_len, head_dim]`. Pure function of the read-only Q/K/V/mask views, so it + /// is safe to call concurrently for distinct heads. 
+ #[allow(clippy::too_many_arguments)] + fn attend_one_head( + &self, + b: usize, + qh: usize, + kvh: usize, + q: ArrayView4, + k: ArrayView4, + v: ArrayView4, + mask: Option<&ArrayView4>, + pv_strided: bool, + ) -> Array2 { + let (_, _, q_len, head_dim) = q.dim(); + let kv_len = k.dim().2; + let scale = self.scale.unwrap_or((head_dim as f32).recip().sqrt()); + let block_kv_len = 32; let block_q_len = 32; - let mut out = Array4::::zeros((batch_size, num_q_heads, q_len, head_dim)); + let mb = b.min(mask.map(|m| m.shape()[0] - 1).unwrap_or(0)); + let mh = qh.min(mask.map(|m| m.shape()[1] - 1).unwrap_or(0)); - for b in 0..batch_size { - let mb = b.min(mask.as_ref().map(|m| m.shape()[0] - 1).unwrap_or(0)); - for kvh in 0..num_kv_heads { - for g in 0..group_size { - let qh = kvh * group_size + g; - let mh = qh.min(mask.as_ref().map(|m| m.shape()[1] - 1).unwrap_or(0)); - let mut l = vec![0f32; q_len]; - let mut m = vec![f32::NEG_INFINITY; q_len]; - for kbix in 0..kv_len.div_ceil(block_kv_len) { - for qbix in 0..q_len.div_ceil(block_q_len) { - let kv_range = - (kbix * block_kv_len)..((kbix + 1) * block_kv_len).min(kv_len); - let q_range = - (qbix * block_q_len)..((qbix + 1) * block_q_len).min(q_len); - if let Some(mask) = &mask { - if mask - .slice(s!(mb, mh, q_range.clone(), kv_range.clone())) - .iter() - .all(|x| *x < -65503.0) - { - continue; - } - } - let m = &mut m[q_range.clone()]; - let l = &mut l[q_range.clone()]; - let qblock: ArrayView2 = q.slice(s!(b, qh, q_range.clone(), ..)); - let kblock: ArrayView2 = k.slice(s!(b, kvh, kv_range.clone(), ..)); - let vblock: ArrayView2 = v.slice(s!(b, kvh, kv_range.clone(), ..)); - let mut oblock: ArrayViewMut2 = - out.slice_mut(s!(b, qh, q_range.clone(), ..)); - // Sij <- QiKTj - let mut s = qblock.dot(&kblock.t()) * scale; - if let Some(mask) = &mask { - s += &mask.slice(s!(mb, mh, q_range.clone(), kv_range.clone())); - } else if self.causal { - let mask = Array2::from_elem( - (q_range.len(), kv_range.len()), - 
f32::NEG_INFINITY, - ); - let mask = mask.triu( - q_range.start as isize - - kv_range.start as isize - - q_len as isize - + kv_len as isize - + 1, - ); - s += &mask; - }; - let tile_m: Vec = s - .rows() - .into_iter() - .map(|row| { - row.iter().copied().map(float_ord::FloatOrd).max().unwrap().0 - }) - .collect_vec(); - for (row_ix, max) in tile_m.iter().enumerate() { - if max.is_finite() { - s.row_mut(row_ix).iter_mut().for_each(|x| *x -= max); - } - } - // Sij <- exp(Sij * scale - max_of_row) - s.mapv_inplace(f32::exp); - let tile_l = s - .sum_axis(tract_ndarray::Axis(1)) - .insert_axis(tract_ndarray::Axis(1)); - // m_new = max(maxes, row_maxs) - let m_new = - (0..q_range.len()).map(|i| m[i].max(tile_m[i])).collect_vec(); - // l_new = exp(m[i] - m_new[i]) * l[i] - exp(tile_m[i] - m_new[i]) * tile_l[i] - let l_new = (0..q_range.len()) - .map(|i| { - (m[i] - m_new[i]).exp() * l[i] - + (tile_m[i] - m_new[i]).exp() * tile_l[(i, 0)] - }) - .collect_vec(); - for i in 0..q_range.len() { - let r_l_new = l_new[i].recip(); - let mul_o = ((m[i] - m_new[i]).exp()) * l[i] * r_l_new; - let mul_sv = ((tile_m[i] - m_new[i]).exp()) * r_l_new; - for j in 0..head_dim { - let sv = s.row(i).dot(&vblock.column(j)); - oblock[(i, j)] = oblock[(i, j)] * mul_o + sv * mul_sv; - } - } - l.copy_from_slice(&l_new); - m.copy_from_slice(&m_new); + let mut out = Array2::::zeros((q_len, head_dim)); + let mut l = vec![0f32; q_len]; + let mut m = vec![f32::NEG_INFINITY; q_len]; + for kbix in 0..kv_len.div_ceil(block_kv_len) { + for qbix in 0..q_len.div_ceil(block_q_len) { + let kv_range = (kbix * block_kv_len)..((kbix + 1) * block_kv_len).min(kv_len); + let q_range = (qbix * block_q_len)..((qbix + 1) * block_q_len).min(q_len); + if let Some(mask) = mask { + if mask + .slice(s!(mb, mh, q_range.clone(), kv_range.clone())) + .iter() + .all(|x| *x < -65503.0) + { + continue; + } + } + let m = &mut m[q_range.clone()]; + let l = &mut l[q_range.clone()]; + let qblock: ArrayView2 = q.slice(s!(b, qh, 
q_range.clone(), ..)); + let kblock: ArrayView2 = k.slice(s!(b, kvh, kv_range.clone(), ..)); + let vblock: ArrayView2 = v.slice(s!(b, kvh, kv_range.clone(), ..)); + let mut oblock: ArrayViewMut2 = out.slice_mut(s!(q_range.clone(), ..)); + // Sij <- QiKTj + let mut s = qblock.dot(&kblock.t()) * scale; + if let Some(mask) = mask { + s += &mask.slice(s!(mb, mh, q_range.clone(), kv_range.clone())); + } else if self.causal { + let mask = + Array2::from_elem((q_range.len(), kv_range.len()), f32::NEG_INFINITY); + let mask = mask.triu( + q_range.start as isize - kv_range.start as isize - q_len as isize + + kv_len as isize + + 1, + ); + s += &mask; + }; + let tile_m: Vec = s + .rows() + .into_iter() + .map(|row| row.iter().copied().map(float_ord::FloatOrd).max().unwrap().0) + .collect_vec(); + for (row_ix, max) in tile_m.iter().enumerate() { + if max.is_finite() { + s.row_mut(row_ix).iter_mut().for_each(|x| *x -= max); + } + } + // Sij <- exp(Sij * scale - max_of_row) + s.mapv_inplace(f32::exp); + let tile_l = s.sum_axis(tract_ndarray::Axis(1)).insert_axis(tract_ndarray::Axis(1)); + // m_new = max(maxes, row_maxs) + let m_new = (0..q_range.len()).map(|i| m[i].max(tile_m[i])).collect_vec(); + // l_new = exp(m[i] - m_new[i]) * l[i] + exp(tile_m[i] - m_new[i]) * tile_l[i] + let l_new = (0..q_range.len()) + .map(|i| { + (m[i] - m_new[i]).exp() * l[i] + + (tile_m[i] - m_new[i]).exp() * tile_l[(i, 0)] + }) + .collect_vec(); + // P·V as one contiguous tile GEMM (s: [q,kv] · vblock: [kv,d]) instead of + // `head_dim` strided column dots; then rescale each row with the online-softmax + // factors. `pv_strided` keeps the old per-column path for A/B benching only. 
+ let sv_tile = if pv_strided { None } else { Some(s.dot(&vblock)) }; + for i in 0..q_range.len() { + let r_l_new = l_new[i].recip(); + let mul_o = ((m[i] - m_new[i]).exp()) * l[i] * r_l_new; + let mul_sv = ((tile_m[i] - m_new[i]).exp()) * r_l_new; + if let Some(sv_tile) = &sv_tile { + for j in 0..head_dim { + oblock[(i, j)] = oblock[(i, j)] * mul_o + sv_tile[(i, j)] * mul_sv; + } + } else { + for j in 0..head_dim { + let sv = s.row(i).dot(&vblock.column(j)); + oblock[(i, j)] = oblock[(i, j)] * mul_o + sv * mul_sv; } } } + l.copy_from_slice(&l_new); + m.copy_from_slice(&m_new); } } out } } + +#[cfg(test)] +mod tests { + use super::*; + use tract_nnef::tract_ndarray::Array4; + + fn rng(n: usize, seed: u64) -> Vec { + let mut s = seed; + (0..n) + .map(|_| { + s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); + ((s >> 40) as f32 / (1u64 << 24) as f32) - 0.5 + }) + .collect() + } + + #[test] + fn flash_sdpa_pv_matches_naive() -> TractResult<()> { + let (b, h, sq, sk, d) = (1usize, 2, 16, 24, 8); // non-square, multi-head + let scale = 1.0f32 / (d as f32).sqrt(); + let qv = rng(b * h * sq * d, 1); + let kv = rng(b * h * sk * d, 2); + let vv = rng(b * h * sk * d, 3); + let q = Array4::from_shape_vec((b, h, sq, d), qv.clone())?; + let k = Array4::from_shape_vec((b, h, sk, d), kv.clone())?; + let v = Array4::from_shape_vec((b, h, sk, d), vv.clone())?; + let op = FlashSdpaOp { causal: false, scale: None }; + let got = op.flash_attention_gqa(q.view(), k.view(), v.view(), None); + let gv: Vec = got.iter().copied().collect(); + let mut want = vec![0f32; b * h * sq * d]; + for bh in 0..b * h { + for i in 0..sq { + let mut sc = vec![0f32; sk]; + for j in 0..sk { + let mut a = 0f32; + for dd in 0..d { + a += qv[(bh * sq + i) * d + dd] * kv[(bh * sk + j) * d + dd]; + } + sc[j] = a * scale; + } + let m = sc.iter().copied().fold(f32::MIN, f32::max); + let mut sum = 0f32; + for x in sc.iter_mut() { + *x = (*x - m).exp(); + sum += *x; + } + for x in 
sc.iter_mut() { + *x /= sum; + } + for e in 0..d { + let mut a = 0f32; + for j in 0..sk { + a += sc[j] * vv[(bh * sk + j) * d + e]; + } + want[(bh * sq + i) * d + e] = a; + } + } + } + let max_abs = gv.iter().zip(want.iter()).map(|(&g, &w)| (g - w).abs()).fold(0f32, f32::max); + println!("flash_sdpa PV-gemm max_abs={max_abs:.6}"); + ensure!(max_abs < 1e-5, "flash_sdpa PV mismatch: {max_abs}"); + Ok(()) + } + + // A/B the P·V step (tile GEMM vs strided per-column dot) across the microbench + // shape AND the real Llama-3.2-1B PP512 prefill shape (causal, GQA 32q/8kv). + // cargo test -p tract-transformers --release bench_flash_sdpa_pv -- --ignored --nocapture + #[test] + #[ignore] + fn bench_flash_sdpa_pv() -> TractResult<()> { + use std::time::Instant; + // (label, hq, hkv, sq, sk, d, causal) + let shapes = [ + ( + "microbench non-causal Hq8 Hkv8 S512 D64", + 8usize, + 8usize, + 512usize, + 512usize, + 64usize, + false, + ), + ("llama-3.2-1B causal Hq32 Hkv8 S512 D64", 32, 8, 512, 512, 64, true), + ]; + for (label, hq, hkv, sq, sk, d, causal) in shapes { + let q = Array4::from_shape_vec((1, hq, sq, d), rng(hq * sq * d, 1))?; + let k = Array4::from_shape_vec((1, hkv, sk, d), rng(hkv * sk * d, 2))?; + let v = Array4::from_shape_vec((1, hkv, sk, d), rng(hkv * sk * d, 3))?; + let op = FlashSdpaOp { causal, scale: None }; + let bench = |strided: bool| -> f64 { + for _ in 0..2 { + std::hint::black_box(op.flash_attention_gqa_impl( + q.view(), + k.view(), + v.view(), + None, + strided, + )); + } + let mut best = f64::MAX; + for _ in 0..8 { + let t = Instant::now(); + for _ in 0..3 { + std::hint::black_box(op.flash_attention_gqa_impl( + q.view(), + k.view(), + v.view(), + None, + strided, + )); + } + best = best.min(t.elapsed().as_secs_f64() / 3.0); + } + best * 1e3 + }; + let strided = bench(true); + let gemm = bench(false); + println!( + "{label}: strided {strided:7.3} ms -> gemm {gemm:7.3} ms ({:.2}x) [per-layer; x16 layers = {:.1} -> {:.1} ms total SDPA]", + strided / 
gemm, + strided * 16.0, + gemm * 16.0, + ); + } + Ok(()) + } +} diff --git a/transformers/src/ops/mod.rs b/transformers/src/ops/mod.rs index 2648a01238..a90e1196ae 100644 --- a/transformers/src/ops/mod.rs +++ b/transformers/src/ops/mod.rs @@ -5,6 +5,7 @@ pub mod flash_sdpa; pub mod scaled_masked_softmax; pub mod sdpa; pub mod streamed_sdpa; +pub mod window_kv_cache; // Re-export ops that moved to core pub mod rms_norm { diff --git a/transformers/src/ops/sdpa.rs b/transformers/src/ops/sdpa.rs index 9eee53d686..f743b5575b 100644 --- a/transformers/src/ops/sdpa.rs +++ b/transformers/src/ops/sdpa.rs @@ -249,6 +249,15 @@ impl Sdpa { } } +/// Minimum K/V sequence length at which an f32 `Sdpa` lowers to the (head-parallel) +/// `FlashSdpaOp` on CPU. Shorter sequences fall back to the decomposed matmul+softmax +/// path. Override with `TRACT_FLASH_SDPA_MIN_SEQ_LEN`; default `0` = always flash +/// (head-parallel flash beat the decomposed path at every sequence length measured on +/// Apple M1, 128–4096; raise this on low-core hosts where short-seq decompose wins). +fn flash_min_seq_len() -> usize { + std::env::var("TRACT_FLASH_SDPA_MIN_SEQ_LEN").ok().and_then(|s| s.parse().ok()).unwrap_or(0) +} + impl Op for Sdpa { fn name(&self) -> StaticName { "SDPA".into() @@ -370,12 +379,21 @@ impl TypedOp for Sdpa { // generic SDPA expansion instead. let q_head_dim = model.outlet_fact(node.inputs[0])?.shape.last().cloned(); let v_head_dim = model.outlet_fact(node.inputs[2])?.shape.last().cloned(); - if q_head_dim == v_head_dim { + // Heuristic: very short K/V sequences are faster through the decomposed + // matmul+softmax path (tract's optimized multipliers) than through the + // block-wise flash kernel. The head-parallel FlashSdpaOp wins at longer + // sequences (and always on memory). Threshold is tunable; default 0 keeps + // flash for every sequence — see `flash_min_seq_len`. 
+ let k_fact = model.outlet_fact(node.inputs[1])?; + let kv_len = k_fact.shape.get(k_fact.rank() - 2).and_then(|d| d.to_usize().ok()); + let too_short = kv_len.is_some_and(|n| n < flash_min_seq_len()); + if q_head_dim == v_head_dim && !too_short { let scale = self.scale.as_ref().map(|t| t.cast_to_scalar()).transpose()?; let op = FlashSdpaOp { causal: self.is_causal, scale }; TypedModelPatch::replace_single_op(model, node, &node.inputs, op).map(Some) } else { - self.patch_sdpa(model, node).context("Wiring fallback SDPA (diff head dims)") + self.patch_sdpa(model, node) + .context("Wiring fallback SDPA (short seq / diff head dims)") } } else { self.patch_sdpa(model, node).context("Wiring fallback SDPA") diff --git a/transformers/src/ops/window_kv_cache.rs b/transformers/src/ops/window_kv_cache.rs new file mode 100644 index 0000000000..930ba1c132 --- /dev/null +++ b/transformers/src/ops/window_kv_cache.rs @@ -0,0 +1,656 @@ +//! Bounded (sliding-window) KV cache — a fixed-capacity ring buffer for models +//! trained with sliding-window attention (Mistral, Gemma-style local/global, …). +//! +//! Instead of growing the cache with the sequence (O(T) memory + O(T) per-step +//! attention), keep a fixed `window` slots and overwrite the oldest on append. So a +//! decode of arbitrary length runs at **constant** memory and per-step cost — and it's +//! **lossless**, because the model was trained to attend only within the window. +//! +//! Key trick that makes the ring buffer cheap: **decode attention is order-invariant +//! over keys** — `O = Σ_j softmax_j · V_j` is unchanged if you permute (K,V) together +//! (the set of scores, hence the softmax weights, is identical). So we never have to +//! un-rotate the buffer: the consumer attends over the W physical slots in whatever +//! order they sit, and the result equals attending over the ordered last-W (up to +//! floating-point summation order). Validated in the tests. +//! +//! 
Companion to the in-place cache (#2321): this is "the in-place cache with a cap and +//! wraparound." For prefill (multi-query) the window also needs a banded causal mask; +//! that lives on the attention op, not here. + +use tract_nnef::internal::*; +use tract_nnef::tract_core::ops::{FrozenOpState, OpStateFreeze}; +use tract_nnef::tract_core::transform::ModelTransform; +use tract_nnef::tract_ndarray::Ix4; + +use crate::ops::dyn_kv_cache::DynKeyValueCache; +use crate::ops::flash_sdpa::FlashSdpaOp; +use crate::ops::sdpa::Sdpa; + +/// NNEF (de)serialization for the fused `WindowKvSdpa` op. +pub fn register(registry: &mut Registry) { + registry.register_dumper(ser_window_kv_sdpa); + registry.register_primitive( + "tract_transformers_window_kv_sdpa", + &[ + TypeName::Scalar.tensor().named("q"), + TypeName::Scalar.tensor().named("k"), + TypeName::Scalar.tensor().named("v"), + TypeName::Integer.named("axis"), + TypeName::Integer.named("window"), + TypeName::Scalar.named("scale"), + ], + &[("output", TypeName::Scalar.tensor())], + de_window_kv_sdpa, + ); +} + +fn ser_window_kv_sdpa( + ast: &mut IntoAst, + node: &TypedNode, + op: &WindowKvSdpa, +) -> TractResult>> { + let q = ast.mapping[&node.inputs[0]].clone(); + let k = ast.mapping[&node.inputs[1]].clone(); + let v = ast.mapping[&node.inputs[2]].clone(); + let mut attrs = vec![("axis", numeric(op.axis)), ("window", numeric(op.window))]; + if let Some(scale) = op.scale { + attrs.push(("scale", numeric(scale))); + } + Ok(Some(invocation("tract_transformers_window_kv_sdpa", &[q, k, v], &attrs))) +} + +fn de_window_kv_sdpa( + builder: &mut ModelBuilder, + invocation: &ResolvedInvocation, +) -> TractResult { + let q = invocation.named_arg_as(builder, "q")?; + let k = invocation.named_arg_as(builder, "k")?; + let v = invocation.named_arg_as(builder, "v")?; + let axis: usize = invocation.named_arg_as(builder, "axis")?; + let window: usize = invocation.named_arg_as(builder, "window")?; + let scale: Option = 
invocation.get_named_arg_as(builder, "scale")?; + builder.wire(WindowKvSdpa { axis, window, scale }, &[q, k, v]) +} + +/// Fixed-capacity sliding-window KV cache (ring buffer) along `axis`. +#[derive(Clone, Debug)] +pub struct WindowKvCache { + pub axis: usize, + pub window: usize, + buf: Option, // capacity `window` along `axis`, allocated on first push + len: usize, // valid slots, ≤ window + cursor: usize, // next write position, in 0..window +} + +impl WindowKvCache { + pub fn new(axis: usize, window: usize) -> Self { + assert!(window > 0, "window must be > 0"); + WindowKvCache { axis, window, buf: None, len: 0, cursor: 0 } + } + + pub fn len(&self) -> usize { + self.len + } + pub fn is_empty(&self) -> bool { + self.len == 0 + } + /// Always the window capacity once allocated — memory is bounded regardless of T. + pub fn capacity(&self) -> usize { + self.buf.as_ref().map(|b| b.shape()[self.axis]).unwrap_or(0) + } + + fn ensure_buf(&mut self, like: &Tensor) -> TractResult<()> { + if self.buf.is_none() { + let mut shape: TVec = like.shape().into(); + shape[self.axis] = self.window; + self.buf = Some(unsafe { Tensor::uninitialized_dt(like.datum_type(), &shape)? }); + } + Ok(()) + } + + /// Append `input` along `axis`, overwriting the oldest slots once full. O(min(new, window)). + pub fn push(&mut self, input: &Tensor) -> TractResult<()> { + let new = input.shape()[self.axis]; + if new == 0 { + return Ok(()); + } + self.ensure_buf(input)?; + let w = self.window; + + if new >= w { + // Only the last `w` of the input survive; lay them out [0..w], cursor resets. + let buf = self.buf.as_mut().unwrap(); + buf.assign_slice(0..w, input, (new - w)..new, self.axis)?; + self.cursor = 0; + self.len = w; + return Ok(()); + } + + // new < w: write at the cursor, wrapping around the end. 
+ let end = self.cursor + new; + let buf = self.buf.as_mut().unwrap(); + if end <= w { + buf.assign_slice(self.cursor..end, input, 0..new, self.axis)?; + self.cursor = if end == w { 0 } else { end }; + } else { + let first = w - self.cursor; + buf.assign_slice(self.cursor..w, input, 0..first, self.axis)?; + buf.assign_slice(0..(new - first), input, first..new, self.axis)?; + self.cursor = new - first; + } + self.len = (self.len + new).min(w); + Ok(()) + } + + /// Zero-copy view of the `len` valid slots. Once full this is the whole buffer in + /// *physical* (rotated) order — correct for decode attention by order-invariance. + pub fn valid_view(&self) -> TractResult> { + let buf = self.buf.as_ref().context("empty window cache")?; + let mut v = buf.to_plain_array_view::()?; + v.slice_axis_inplace(tract_ndarray::Axis(self.axis), (0..self.len).into()); + Ok(v) + } +} + +/// Fused sliding-window KV-cache + attention (decode). Owns K/V ring buffers of size +/// `window`; each step appends `K_new`/`V_new` and attends `Q` over the (≤window) cache. +/// The bounded cache *is* the sliding window — attending over it equals windowed +/// attention — so decode runs at constant memory + per-step cost, losslessly. Inputs +/// `[Q, K_new, V_new]`, each `[B, H, S, D]`; output has Q's shape. 
+#[derive(Clone, Debug, PartialEq)] +pub struct WindowKvSdpa { + pub axis: usize, + pub window: usize, + pub scale: Option, +} +impl Eq for WindowKvSdpa {} + +impl Op for WindowKvSdpa { + fn name(&self) -> StaticName { + "WindowKvSdpa".into() + } + fn info(&self) -> TractResult> { + Ok(vec![format!("axis={}, window={}, scale={:?}", self.axis, self.window, self.scale)]) + } + op_as_typed_op!(); +} + +impl EvalOp for WindowKvSdpa { + fn is_stateless(&self) -> bool { + false + } + fn state( + &self, + _session: &TurnState, + _node_id: usize, + ) -> TractResult>> { + Ok(Some(Box::new(WindowKvSdpaState { + window: self.window, + scale: self.scale, + k: WindowKvCache::new(self.axis, self.window), + v: WindowKvCache::new(self.axis, self.window), + }))) + } +} + +impl TypedOp for WindowKvSdpa { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + ensure!(inputs.len() == 3, "WindowKvSdpa expects [Q, K_new, V_new]"); + Ok(tvec!(inputs[0].without_value())) + } + as_op!(); +} + +#[derive(Clone, Debug)] +pub struct WindowKvSdpaState { + window: usize, + scale: Option, + k: WindowKvCache, + v: WindowKvCache, +} + +impl OpState for WindowKvSdpaState { + fn eval( + &mut self, + _state: &mut TurnState, + _op: &dyn Op, + inputs: TVec, + ) -> TractResult> { + ensure!(inputs.len() == 3, "WindowKvSdpa expects [Q, K_new, V_new]"); + let input_dt = inputs[0].datum_type(); + let k_new = inputs[1].cast_to::()?; + let v_new = inputs[2].cast_to::()?; + self.k.push(k_new.as_ref())?; + self.v.push(v_new.as_ref())?; + + let q = inputs[0].cast_to::()?; + let qv = q.to_plain_array_view::()?.into_dimensionality::()?; + let kview = self.k.valid_view::()?.into_dimensionality::()?; + let vview = self.v.valid_view::()?.into_dimensionality::()?; + + // The ring buffer already bounds which keys are visible (the last `window`), so + // attention is "attend all" — every cached key is within the current query's window. 
+ let flash = FlashSdpaOp { causal: false, scale: self.scale }; + let o = flash.flash_attention_gqa(qv, kview, vview, None); + Ok(tvec!(o.into_tensor().cast_to_dt(input_dt)?.into_owned().into_tvalue())) + } +} + +#[derive(Clone, Debug)] +struct FrozenWindowKvSdpaState { + window: usize, + scale: Option, + k: WindowKvCache, + v: WindowKvCache, +} +impl OpStateFreeze for WindowKvSdpaState { + fn freeze(&self) -> Box { + Box::new(FrozenWindowKvSdpaState { + window: self.window, + scale: self.scale, + k: self.k.clone(), + v: self.v.clone(), + }) + } +} +impl FrozenOpState for FrozenWindowKvSdpaState { + fn unfreeze(&self) -> Box { + Box::new(WindowKvSdpaState { + window: self.window, + scale: self.scale, + k: self.k.clone(), + v: self.v.clone(), + }) + } +} + +/// Rewrite rule: fuse `{DynKeyValueCache(K), DynKeyValueCache(V), Sdpa(Q,K,V)}` into a +/// `WindowKvSdpa` with the window supplied via the Rewriter context — so an imported +/// decode model uses a bounded sliding-window cache. The window comes from the model +/// (the GQA `local_window_size` / config), passed to `WindowKvSdpaTransform`. 
+pub fn fuse_window_kv_sdpa_rule( + window: &usize, + model: &TypedModel, + node: &TypedNode, + node_name: &str, + op: &Sdpa, +) -> TractResult> { + if node.inputs.len() != 3 { + return Ok(None); + } + let k_node = model.node(node.inputs[1].node); + let v_node = model.node(node.inputs[2].node); + let (Some(kc), Some(vc)) = + (k_node.op_as::(), v_node.op_as::()) + else { + return Ok(None); + }; + if kc.axis != vc.axis { + return Ok(None); + } + if k_node.outputs[0].successors.len() != 1 || v_node.outputs[0].successors.len() != 1 { + return Ok(None); + } + let scale = op.scale.as_ref().map(|t| t.cast_to_scalar::()).transpose()?; + let q_outlet = node.inputs[0]; + let k_new = k_node.inputs[0]; + let v_new = v_node.inputs[0]; + + let mut patch = TypedModelPatch::default(); + let taps = patch.taps(model, &[q_outlet, k_new, v_new])?; + let fused = patch.wire_node( + format!("{node_name}.window_kv_sdpa"), + WindowKvSdpa { axis: kc.axis, window: *window, scale }, + &taps, + )?; + patch.shunt_outside(model, node.id.into(), fused[0])?; + Ok(Some(patch)) +} + +/// Strip the GQA broadcast chain, then fuse `cache -> Sdpa` into `WindowKvSdpa` with +/// `window` — making an imported decode model use the bounded sliding-window cache. 
+#[derive(Debug, Clone)] +pub struct WindowKvSdpaTransform { + pub window: usize, +} + +impl ModelTransform for WindowKvSdpaTransform { + fn name(&self) -> StaticName { + "fuse_window_kv_sdpa".into() + } + fn transform(&self, model: &mut TypedModel) -> TractResult<()> { + Rewriter::default() + .with_rule_for("fuse-kv-broadcast", crate::ops::sdpa::fuse_kv_cache_broadcast_rule) + .rewrite(&(), model)?; + Rewriter::default() + .with_rule_for("fuse-window-kv-sdpa", fuse_window_kv_sdpa_rule) + .rewrite(&self.window, model)?; + model.compact() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tract_nnef::tract_ndarray::{Array4, ArrayView4, s}; + + fn tok(shape: &[usize], v: f32) -> Tensor { + let n: usize = shape.iter().product(); + Tensor::from_shape(shape, &vec![v; n]).unwrap() + } + + // ---- ring-buffer mechanics: holds exactly the last `window` items (as a set) ---- + #[test] + fn window_holds_last_w_as_a_set() -> TractResult<()> { + let w = 4; + let mut c = WindowKvCache::new(2, w); // [B,H,S,D], seq axis + let mut full: Vec = vec![]; + for t in 0..10 { + c.push(&tok(&[1, 1, 1, 1], t as f32))?; + full.push(t as f32); + assert!(c.len() <= w, "len bounded by window"); + // valid_view set == last min(t+1,w) tokens + let view = c.valid_view::()?; + let mut got: Vec = view.iter().copied().collect(); + got.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let mut want: Vec = full.iter().rev().take(w).copied().collect(); + want.sort_by(|a, b| a.partial_cmp(b).unwrap()); + assert_eq!(got, want, "step {t}: window must hold the last {w} as a set"); + } + assert_eq!(c.capacity(), w, "memory stays bounded at the window"); + Ok(()) + } + + #[test] + fn prefill_chunk_larger_than_window_keeps_last_w() -> TractResult<()> { + let w = 3; + let mut c = WindowKvCache::new(2, w); + // one chunk of 7 tokens with distinct values, axis=2 + let chunk = Tensor::from_shape(&[1, 1, 7, 1], &[0f32, 1., 2., 3., 4., 5., 6.])?; + c.push(&chunk)?; + let mut got: Vec = 
c.valid_view::()?.iter().copied().collect(); + got.sort_by(|a, b| a.partial_cmp(b).unwrap()); + assert_eq!(got, vec![4.0, 5.0, 6.0], "keeps the last 3"); + Ok(()) + } + + // ---- the correctness property: decode attention over the (rotated) window == + // attention over the ordered last-W slice, by order-invariance ---- + fn attention( + q: ArrayView4, + k: ArrayView4, + v: ArrayView4, + scale: f32, + ) -> Array4 { + let (b, h, sq, d) = q.dim(); + let mut out = Array4::::zeros((b, h, sq, d)); + for bi in 0..b { + for hi in 0..h { + let qm = q.slice(s![bi, hi, .., ..]); + let km = k.slice(s![bi, hi, .., ..]); + let vm = v.slice(s![bi, hi, .., ..]); + let mut sc = qm.dot(&km.t()); + sc *= scale; + for mut row in sc.rows_mut() { + let m = row.iter().copied().fold(f32::NEG_INFINITY, f32::max); + let mut s = 0.0; + row.iter_mut().for_each(|x| { + *x = (*x - m).exp(); + s += *x; + }); + row.iter_mut().for_each(|x| *x /= s); + } + out.slice_mut(s![bi, hi, .., ..]).assign(&sc.dot(&vm)); + } + } + out + } + + #[test] + fn windowed_attention_matches_last_w_full() -> TractResult<()> { + let (h, d, w) = (2usize, 8usize, 6usize); + let scale = 1.0 / (d as f32).sqrt(); + // deterministic varied K/V per token so order actually differs after wrap + let seq = |s: usize, base: f32| -> Tensor { + let data: Vec = (0..h * s * d).map(|i| base + (i as f32 * 0.013).sin()).collect(); + Tensor::from_shape(&[1, h, s, d], &data).unwrap() + }; + let mut kc = WindowKvCache::new(2, w); + let mut vc = WindowKvCache::new(2, w); + let mut kfull: Option = None; + let mut vfull: Option = None; + use tract_nnef::tract_core::ops::array::TypedConcat; + for t in 0..20 { + let knew = seq(1, 1.0 + t as f32 * 0.1); + let vnew = seq(1, 5.0 - t as f32 * 0.07); + kc.push(&knew)?; + vc.push(&vnew)?; + let grow = |acc: Option, x: Tensor| -> TractResult { + Ok(match acc { + None => x, + Some(a) => TypedConcat { axis: 2 } + .eval(tvec![a.into(), x.into()])? 
+ .remove(0) + .into_tensor(), + }) + }; + kfull = Some(grow(kfull.take(), knew)?); + vfull = Some(grow(vfull.take(), vnew)?); + + let q = seq(1, 9.0 + t as f32 * 0.05); + let qv = q.to_plain_array_view::()?.into_dimensionality()?; + + // windowed (rotated physical order) + let o_win = attention( + qv, + kc.valid_view::()?.into_dimensionality()?, + vc.valid_view::()?.into_dimensionality()?, + scale, + ); + // reference: ordered last-W of the full cache + let len = kc.len(); + let kf = kfull.as_ref().unwrap(); + let s = kf.shape()[2]; + let kslice = kf.slice(2, s - len, s)?; + let vslice = vfull.as_ref().unwrap().slice(2, s - len, s)?; + let o_ref = attention( + qv, + kslice.to_plain_array_view::()?.into_dimensionality()?, + vslice.to_plain_array_view::()?.into_dimensionality()?, + scale, + ); + let a = Tensor::from(o_win); + let b = Tensor::from(o_ref); + a.close_enough(&b, Approximation::Approximate) + .with_context(|| format!("windowed != last-W at step {t}"))?; + } + Ok(()) + } + + // The fused decode op, run through tract's engine over a long sequence with a small + // window, equals full attention over the last-W each step — i.e. correct sliding-window + // decode, with the cache bounded to `window` regardless of how long we decode. 
+ #[test] + fn window_sdpa_decode_matches_last_w_in_model() -> TractResult<()> { + use tract_nnef::tract_core::ops::array::TypedConcat; + let (b, h, d, w) = (1usize, 2usize, 16usize, 5usize); + let scale = 1.0 / (d as f32).sqrt(); + let mut model = TypedModel::default(); + let s = model.sym("S"); + let dim = |x: usize| x.to_dim(); + let f: TVec = tvec![dim(b), dim(h), s.into(), dim(d)]; + let q = model.add_source("q", f32::fact(&f))?; + let k = model.add_source("k", f32::fact(&f))?; + let v = model.add_source("v", f32::fact(&f))?; + let o = + model.wire_node("win", WindowKvSdpa { axis: 2, window: w, scale: None }, &[q, k, v])?; + model.select_output_outlets(&o)?; + let mut rt = model.into_runnable()?.spawn()?; + + let mk = |base: f32| -> Tensor { + let data: Vec = (0..b * h * d).map(|i| base + (i as f32 * 0.013).sin()).collect(); + Tensor::from_shape(&[b, h, 1, d], &data).unwrap() + }; + let grow = |acc: Option, x: Tensor| -> TractResult { + Ok(match acc { + None => x, + Some(a) => { + TypedConcat { axis: 2 }.eval(tvec![a.into(), x.into()])?.remove(0).into_tensor() + } + }) + }; + let (mut kf, mut vf): (Option, Option) = (None, None); + for t in 0..15 { + let qi = mk(9.0 + t as f32 * 0.1); + let ki = mk(1.0 + t as f32 * 0.07); + let vi = mk(5.0 - t as f32 * 0.05); + let o_model = rt + .run(tvec![qi.clone().into(), ki.clone().into(), vi.clone().into()])? 
+ .remove(0) + .into_tensor(); + kf = Some(grow(kf.take(), ki)?); + vf = Some(grow(vf.take(), vi)?); + let fk = kf.as_ref().unwrap(); + let sk = fk.shape()[2]; + let len = sk.min(w); + let kslice = fk.slice(2, sk - len, sk)?; + let vslice = vf.as_ref().unwrap().slice(2, sk - len, sk)?; + let qv = qi.to_plain_array_view::()?.into_dimensionality()?; + let o_ref = attention( + qv, + kslice.to_plain_array_view::()?.into_dimensionality()?, + vslice.to_plain_array_view::()?.into_dimensionality()?, + scale, + ); + o_model + .close_enough(&Tensor::from(o_ref), Approximation::Approximate) + .with_context(|| format!("window decode != last-{w} at step {t}"))?; + } + Ok(()) + } + + // Auto-wiring: a {DynKeyValueCache(K), DynKeyValueCache(V), Sdpa} decode subgraph is + // fused to WindowKvSdpa{window} by the transform, and the rewritten model then does + // correct windowed decode (vs the un-windowed full attention it had before). + #[test] + fn transform_fuses_cache_sdpa_to_windowed_decode() -> TractResult<()> { + use crate::ops::dyn_kv_cache::DynKeyValueCache; + use crate::ops::sdpa::Sdpa; + use tract_nnef::tract_core::ops::array::TypedConcat; + let (b, h, d, w) = (1usize, 2usize, 16usize, 5usize); + let scale = 1.0 / (d as f32).sqrt(); + let mut model = TypedModel::default(); + let s = model.sym("S"); + let p = model.sym("P"); + let dim = |x: usize| x.to_dim(); + let newf: TVec = tvec![dim(b), dim(h), s.clone().into(), dim(d)]; + let qf: TVec = tvec![dim(b), dim(h), s.into(), dim(d)]; + let pastf: TVec = tvec![dim(b), dim(h), p.into(), dim(d)]; + let q = model.add_source("q", f32::fact(&qf))?; + let knew = model.add_source("k", f32::fact(&newf))?; + let vnew = model.add_source("v", f32::fact(&newf))?; + let mkc = |nm: &str| DynKeyValueCache { + name: nm.to_string(), + axis: 2, + past_sequence_fact: f32::fact(&pastf), + input_sequence_fact: f32::fact(&newf), + }; + let kc = model.wire_node("kc", mkc("kc"), &[knew])?; + let vc = model.wire_node("vc", mkc("vc"), &[vnew])?; + 
let o = model.wire_node( + "sdpa", + Sdpa { + scale: None, + datum_type: f32::datum_type(), + acc_datum_type: f32::datum_type(), + is_causal: true, + }, + &[q, kc[0], vc[0]], + )?; + model.select_output_outlets(&o)?; + + WindowKvSdpaTransform { window: w }.transform(&mut model)?; + + assert!(model.nodes().iter().any(|n| n.op_is::()), "fused to WindowKvSdpa"); + assert!(!model.nodes().iter().any(|n| n.op_is::()), "caches removed"); + assert!(!model.nodes().iter().any(|n| n.op_is::()), "sdpa removed"); + + let mut rt = model.into_runnable()?.spawn()?; + let mk = |base: f32| -> Tensor { + let data: Vec = (0..b * h * d).map(|i| base + (i as f32 * 0.013).sin()).collect(); + Tensor::from_shape(&[b, h, 1, d], &data).unwrap() + }; + let grow = |acc: Option, x: Tensor| -> TractResult { + Ok(match acc { + None => x, + Some(a) => { + TypedConcat { axis: 2 }.eval(tvec![a.into(), x.into()])?.remove(0).into_tensor() + } + }) + }; + let (mut kf, mut vf): (Option, Option) = (None, None); + for t in 0..15 { + let qi = mk(9.0 + t as f32 * 0.1); + let ki = mk(1.0 + t as f32 * 0.07); + let vi = mk(5.0 - t as f32 * 0.05); + let o_model = rt + .run(tvec![qi.clone().into(), ki.clone().into(), vi.clone().into()])? + .remove(0) + .into_tensor(); + kf = Some(grow(kf.take(), ki)?); + vf = Some(grow(vf.take(), vi)?); + let fk = kf.as_ref().unwrap(); + let sk = fk.shape()[2]; + let len = sk.min(w); + let kslice = fk.slice(2, sk - len, sk)?; + let vslice = vf.as_ref().unwrap().slice(2, sk - len, sk)?; + let qv = qi.to_plain_array_view::()?.into_dimensionality()?; + let o_ref = attention( + qv, + kslice.to_plain_array_view::()?.into_dimensionality()?, + vslice.to_plain_array_view::()?.into_dimensionality()?, + scale, + ); + o_model + .close_enough(&Tensor::from(o_ref), Approximation::Approximate) + .with_context(|| format!("rewritten windowed decode != last-{w} at step {t}"))?; + } + Ok(()) + } + + // NNEF ser/de round-trip: WindowKvSdpa survives write_to_tar -> model_for_read. 
+ #[test] + fn window_kv_sdpa_nnef_round_trip() -> TractResult<()> { + use crate::WithTractTransformers; + let (b, h, d) = (1usize, 2usize, 16usize); + let mut model = TypedModel::default(); + let s = model.sym("S"); + let dim = |x: usize| x.to_dim(); + let f: TVec = tvec![dim(b), dim(h), s.into(), dim(d)]; + let q = model.add_source("q", f32::fact(&f))?; + let k = model.add_source("k", f32::fact(&f))?; + let v = model.add_source("v", f32::fact(&f))?; + let o = model.wire_node( + "win", + WindowKvSdpa { axis: 2, window: 4096, scale: Some(0.125) }, + &[q, k, v], + )?; + model.select_output_outlets(&o)?; + + let nnef = tract_nnef::nnef().with_tract_transformers(); + let mut buffer = vec![]; + nnef.write_to_tar(&model, &mut buffer)?; + let reloaded = nnef.model_for_read(&mut &*buffer)?; + + let n = reloaded + .nodes() + .iter() + .find(|n| n.op_is::()) + .context("WindowKvSdpa survived the round-trip")?; + let op = n.op_as::().unwrap(); + assert_eq!(op.axis, 2); + assert_eq!(op.window, 4096); + assert_eq!(op.scale, Some(0.125)); + Ok(()) + } +}